Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-01-04 03:25:30 +08:00)

Compare commits: v0.22.0 ... 0db00f70b2 (46 commits)
Commits (SHA1):
0db00f70b2, 701761d119, 2993fc666b, 8a6d205df0, 912b6b023e, 89e8818dda, 1dba6b5bf9, 3fcf2ee54c,
d8f413a885, 7264fb6978, bd4bc57009, 0569b50fed, 6b64641042, 9cef3a2625, e7e89d3ecb, 13e212c856,
61cf430dbb, e841b09d63, b1a1eedf53, 68e3b33ae4, cd55f6c1b8, 996b5fe14e, db4fd19c82, 12db62b9c7,
b5f2cf16bc, e27ff8d3d4, 5f59418aba, 87e69868c0, 72c20022f6, 3f2472f1b9, 1d4d67daf8, 7538e218a5,
6b52f7df5a, 63131ec9b2, e8f1a245a6, 908450509f, 70a0f081f6, 93422fa8cc, bfc84ba95b, 871055b0fc,
ba71160b14, bd5dda6b10, 774563970b, 83d84e90ed, 8ef2f79d0a, 296476ab89
.github/workflows/tests.yml (vendored): 32 lines changed
@@ -95,6 +95,38 @@ jobs:
          version: ">=0.11.x"
          args: "check"

      - name: Check comments of changed Python files
        if: ${{ false }}
        run: |
          if [[ ${{ github.event_name }} == 'pull_request_target' ]]; then
            CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }} \
              | grep -E '\.(py)$' || true)

            if [ -n "$CHANGED_FILES" ]; then
              echo "Check comments of changed Python files with check_comment_ascii.py"

              readarray -t files <<< "$CHANGED_FILES"
              HAS_ERROR=0

              for file in "${files[@]}"; do
                if [ -f "$file" ]; then
                  if python3 check_comment_ascii.py "$file"; then
                    echo "✅ $file"
                  else
                    echo "❌ $file"
                    HAS_ERROR=1
                  fi
                fi
              done

              if [ $HAS_ERROR -ne 0 ]; then
                exit 1
              fi
            else
              echo "No Python files changed"
            fi
          fi

      - name: Build ragflow:nightly
        run: |
          RUNNER_WORKSPACE_PREFIX=${RUNNER_WORKSPACE_PREFIX:-${HOME}}
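The workflow step above shells out to check_comment_ascii.py, which lives in the repository but is not shown in this diff. A minimal sketch of what such a checker could look like (an assumption for illustration, not the project's actual script):

```python
# Hypothetical sketch of a comment-ASCII checker; the real check_comment_ascii.py may differ.
import sys
import tokenize


def check_file(path: str) -> bool:
    """Return True when every comment in the file is pure ASCII."""
    ok = True
    with open(path, "rb") as f:
        for tok in tokenize.tokenize(f.readline):
            if tok.type == tokenize.COMMENT and not tok.string.isascii():
                print(f"{path}:{tok.start[0]}: non-ASCII comment: {tok.string!r}")
                ok = False
    return ok


if __name__ == "__main__":
    sys.exit(0 if check_file(sys.argv[1]) else 1)
```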
@@ -51,7 +51,9 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
    apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
    apt install -y libjemalloc-dev && \
    apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
    apt install -y ghostscript
    apt install -y ghostscript && \
    apt install -y pandoc && \
    apt install -y texlive

RUN if [ "$NEED_MIRROR" == "1" ]; then \
    pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
@@ -192,9 +192,10 @@ releases! 🌟

```bash
$ cd ragflow/docker

# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.

# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

@@ -192,6 +192,7 @@ Coba demo kami di [https://demo.ragflow.io](https://demo.ragflow.io).
$ cd ragflow/docker

# Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases), contoh: git checkout v0.22.0
# This step ensures the **entrypoint.sh** file in the code matches the Docker image version.

# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

@@ -172,6 +172,7 @@
$ cd ragflow/docker

# 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) 例: git checkout v0.22.0
# この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。

# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

@@ -174,6 +174,7 @@
$ cd ragflow/docker

# Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases), e.g.: git checkout v0.22.0
# 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다.

# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

@@ -192,6 +192,7 @@ Experimente nossa demo em [https://demo.ragflow.io](https://demo.ragflow.io).
$ cd ragflow/docker

# Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases), ex.: git checkout v0.22.0
# Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker.

# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

@@ -191,6 +191,7 @@
$ cd ragflow/docker

# 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases),例:git checkout v0.22.0
# 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。

# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d

@@ -192,6 +192,7 @@
$ cd ragflow/docker

# 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases),例如:git checkout v0.22.0
# 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。

# Use CPU for DeepDoc tasks:
$ docker compose -f docker-compose.yml up -d
@@ -4,7 +4,7 @@

Admin Service is a dedicated management component designed to monitor, maintain, and administrate the RAGFlow system. It provides comprehensive tools for ensuring system stability, performing operational tasks, and managing users and permissions efficiently.

The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.
The service offers real-time monitoring of critical components, including the RAGFlow server, Task Executor processes, and dependent services such as MySQL, Infinity, Elasticsearch, Redis, and MinIO. It automatically checks their health status, resource usage, and uptime, and performs restarts in case of failures to minimize downtime.

For user and system management, it supports listing, creating, modifying, and deleting users and their associated resources like knowledge bases and Agents.
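To make the monitoring idea above concrete, here is a minimal, self-contained health-polling sketch; the host/port table and the TCP-only check are illustrative assumptions, not the Admin Service's actual implementation:

```python
# Minimal sketch of a health-polling loop in the spirit of the Admin Service.
# Service endpoints below are placeholders.
import socket


def tcp_alive(host: str, port: int, timeout: float = 2.0) -> bool:
    """Report whether a TCP service accepts connections."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


SERVICES = {  # hypothetical endpoints
    "mysql": ("127.0.0.1", 3306),
    "elasticsearch": ("127.0.0.1", 9200),
    "redis": ("127.0.0.1", 6379),
}

if __name__ == "__main__":
    for name, (host, port) in SERVICES.items():
        print(f"{name}: {'alive' if tcp_alive(host, port) else 'timeout'}")
```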
@@ -378,7 +378,7 @@ class AdminCLI(Cmd):
            self.session.headers.update({
                'Content-Type': 'application/json',
                'Authorization': response.headers['Authorization'],
                'User-Agent': 'RAGFlow-CLI/0.22.0'
                'User-Agent': 'RAGFlow-CLI/0.22.1'
            })
            print("Authentication successful.")
            return True

@@ -393,7 +393,9 @@ class AdminCLI(Cmd):
            print(f"Can't access {self.host}, port: {self.port}")

    def _format_service_detail_table(self, data):
        if not any([isinstance(v, list) for v in data.values()]):
        if isinstance(data, list):
            return data
        if not all([isinstance(v, list) for v in data.values()]):
            # normal table
            return data
        # handle task_executor heartbeats map, for example {'name': [{'done': 2, 'now': timestamp1}, {'done': 3, 'now': timestamp2}]

@@ -404,7 +406,7 @@ class AdminCLI(Cmd):
            task_executor_list.append({
                "task_executor_name": k,
                **heartbeats[0],
            })
            } if heartbeats else {"task_executor_name": k})
        return task_executor_list

    def _print_table_simple(self, data):

@@ -415,7 +417,8 @@ class AdminCLI(Cmd):
            # handle single row data
            data = [data]

        columns = list(data[0].keys())
        columns = list(set().union(*(d.keys() for d in data)))
        columns.sort()
        col_widths = {}

        def get_string_width(text):
@@ -1,6 +1,6 @@
[project]
name = "ragflow-cli"
version = "0.22.0"
version = "0.22.1"
description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring."
authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }]
license = { text = "Apache License, Version 2.0" }
@@ -169,7 +169,7 @@ def login_verify(f):
        username = auth.parameters['username']
        password = auth.parameters['password']
        try:
            if check_admin(username, password) is False:
            if not check_admin(username, password):
                return jsonify({
                    "code": 500,
                    "message": "Access denied",
@@ -25,8 +25,21 @@ from common.config_utils import read_config
from urllib.parse import urlparse


class BaseConfig(BaseModel):
    id: int
    name: str
    host: str
    port: int
    service_type: str
    detail_func_name: str

    def to_dict(self) -> dict[str, Any]:
        return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
                'service_type': self.service_type}


class ServiceConfigs:
    configs = dict
    configs = list[BaseConfig]

    def __init__(self):
        self.configs = []

@@ -45,19 +58,6 @@ class ServiceType(Enum):
    FILE_STORE = "file_store"


class BaseConfig(BaseModel):
    id: int
    name: str
    host: str
    port: int
    service_type: str
    detail_func_name: str

    def to_dict(self) -> dict[str, Any]:
        return {'id': self.id, 'name': self.name, 'host': self.host, 'port': self.port,
                'service_type': self.service_type}


class MetaConfig(BaseConfig):
    meta_type: str

@@ -227,7 +227,7 @@ def load_configurations(config_path: str) -> list[BaseConfig]:
    ragflow_count = 0
    id_count = 0
    for k, v in raw_configs.items():
        match (k):
        match k:
            case "ragflow":
                name: str = f'ragflow_{ragflow_count}'
                host: str = v['host']
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging
import re
from werkzeug.security import check_password_hash
from common.constants import ActiveEnum

@@ -190,7 +189,8 @@ class ServiceMgr:
                config_dict['status'] = service_detail['status']
            else:
                config_dict['status'] = 'timeout'
        except Exception:
        except Exception as e:
            logging.warning(f"Can't get service details, error: {e}")
            config_dict['status'] = 'timeout'
        if not config_dict['host']:
            config_dict['host'] = '-'

@@ -205,17 +205,13 @@ class ServiceMgr:

    @staticmethod
    def get_service_details(service_id: int):
        service_id = int(service_id)
        service_idx = int(service_id)
        configs = SERVICE_CONFIGS.configs
        service_config_mapping = {
            c.id: {
                'name': c.name,
                'detail_func_name': c.detail_func_name
            } for c in configs
        }
        service_info = service_config_mapping.get(service_id, {})
        if not service_info:
            raise AdminException(f"invalid service_id: {service_id}")
        if service_idx < 0 or service_idx >= len(configs):
            raise AdminException(f"invalid service_index: {service_idx}")

        service_config = configs[service_idx]
        service_info = {'name': service_config.name, 'detail_func_name': service_config.detail_func_name}

        detail_func = getattr(health_utils, service_info.get('detail_func_name'))
        res = detail_func()
@@ -298,8 +298,6 @@ class Canvas(Graph):
                for kk, vv in kwargs["webhook_payload"].items():
                    self.components[k]["obj"].set_output(kk, vv)

            self.components[k]["obj"].reset(True)

        for k in kwargs.keys():
            if k in ["query", "user_id", "files"] and kwargs[k]:
                if k == "files":

@@ -408,6 +406,10 @@ class Canvas(Graph):
                else:
                    yield decorate("message", {"content": cpn_obj.output("content")})
                cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))

                if isinstance(cpn_obj.output("attachment"), tuple):
                    yield decorate("message", {"attachment": cpn_obj.output("attachment")})

                yield decorate("message_end", {"reference": self.get_reference() if cite else None})

            while partials:
@@ -30,7 +30,7 @@ from api.db.services.mcp_server_service import MCPServerService
from common.connection_utils import timeout
from rag.prompts.generator import next_step, COMPLETE_TASK, analyze_task, \
    citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question, message_fit_in
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
from agent.component.llm import LLMParam, LLM
@@ -368,11 +368,19 @@ Respond immediately with your final comprehensive answer.

        return "Error occurred."

    def reset(self, temp=False):
    def reset(self, only_output=False):
        """
        Reset all tools if they have a reset method. This avoids errors for tools like MCPToolCallSession.
        """
        for k in self._param.outputs.keys():
            self._param.outputs[k]["value"] = None

        for k, cpn in self.tools.items():
            if hasattr(cpn, "reset") and callable(cpn.reset):
                cpn.reset()
        if only_output:
            return
        for k in self._param.inputs.keys():
            self._param.inputs[k]["value"] = None
        self._param.debug_inputs = {}
@@ -463,12 +463,15 @@ class ComponentBase(ABC):
        return self._param.outputs.get("_ERROR", {}).get("value")

    def reset(self, only_output=False):
        for k in self._param.outputs.keys():
            self._param.outputs[k]["value"] = None
        outputs: dict = self._param.outputs  # for better performance
        for k in outputs.keys():
            outputs[k]["value"] = None
        if only_output:
            return
        for k in self._param.inputs.keys():
            self._param.inputs[k]["value"] = None

        inputs: dict = self._param.inputs  # for better performance
        for k in inputs.keys():
            inputs[k]["value"] = None
        self._param.debug_inputs = {}

    def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:
agent/component/list_operations.py (new file, 166 lines)
@@ -0,0 +1,166 @@
from abc import ABC
import os
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout


class ListOperationsParam(ComponentParamBase):
    """
    Define the List Operations component parameters.
    """
    def __init__(self):
        super().__init__()
        self.query = ""
        self.operations = "topN"
        self.n=0
        self.sort_method = "asc"
        self.filter = {
            "operator": "=",
            "value": ""
        }
        self.outputs = {
            "result": {
                "value": [],
                "type": "Array of ?"
            },
            "first": {
                "value": "",
                "type": "?"
            },
            "last": {
                "value": "",
                "type": "?"
            }
        }

    def check(self):
        self.check_empty(self.query, "query")
        self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"])

    def get_input_form(self) -> dict[str, dict]:
        return {}


class ListOperations(ComponentBase,ABC):
    component_name = "ListOperations"

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
    def _invoke(self, **kwargs):
        self.input_objects=[]
        inputs = getattr(self._param, "query", None)
        self.inputs=self._canvas.get_variable_value(inputs)
        self.set_input_value(inputs, self.inputs)
        if self._param.operations == "topN":
            self._topN()
        elif self._param.operations == "head":
            self._head()
        elif self._param.operations == "tail":
            self._tail()
        elif self._param.operations == "filter":
            self._filter()
        elif self._param.operations == "sort":
            self._sort()
        elif self._param.operations == "drop_duplicates":
            self._drop_duplicates()

    def _coerce_n(self):
        try:
            return int(getattr(self._param, "n", 0))
        except Exception:
            return 0

    def _set_outputs(self, outputs):
        self._param.outputs["result"]["value"] = outputs
        self._param.outputs["first"]["value"] = outputs[0] if outputs else None
        self._param.outputs["last"]["value"] = outputs[-1] if outputs else None

    def _topN(self):
        n = self._coerce_n()
        if n < 1:
            outputs = []
        else:
            n = min(n, len(self.inputs))
            outputs = self.inputs[:n]
        self._set_outputs(outputs)

    def _head(self):
        n = self._coerce_n()
        if 1 <= n <= len(self.inputs):
            outputs = [self.inputs[n - 1]]
        else:
            outputs = []
        self._set_outputs(outputs)

    def _tail(self):
        n = self._coerce_n()
        if 1 <= n <= len(self.inputs):
            outputs = [self.inputs[-n]]
        else:
            outputs = []
        self._set_outputs(outputs)

    def _filter(self):
        self._set_outputs([i for i in self.inputs if self._eval(self._norm(i),self._param.filter["operator"],self._param.filter["value"])])

    def _norm(self,v):
        s = "" if v is None else str(v)
        return s

    def _eval(self, v, operator, value):
        if operator == "=":
            return v == value
        elif operator == "≠":
            return v != value
        elif operator == "contains":
            return value in v
        elif operator == "start with":
            return v.startswith(value)
        elif operator == "end with":
            return v.endswith(value)
        else:
            return False

    def _sort(self):
        items = self.inputs or []
        method = getattr(self._param, "sort_method", "asc") or "asc"
        reverse = method == "desc"

        if not items:
            self._set_outputs([])
            return

        first = items[0]

        if isinstance(first, dict):
            outputs = sorted(
                items,
                key=lambda x: self._hashable(x),
                reverse=reverse,
            )
        else:
            outputs = sorted(items, reverse=reverse)

        self._set_outputs(outputs)

    def _drop_duplicates(self):
        seen = set()
        outs = []
        for item in self.inputs:
            k = self._hashable(item)
            if k in seen:
                continue
            seen.add(k)
            outs.append(item)
        self._set_outputs(outs)

    def _hashable(self,x):
        if isinstance(x, dict):
            return tuple(sorted((k, self._hashable(v)) for k, v in x.items()))
        if isinstance(x, (list, tuple)):
            return tuple(self._hashable(v) for v in x)
        if isinstance(x, set):
            return tuple(sorted(self._hashable(v) for v in x))
        return x

    def thoughts(self) -> str:
        return "ListOperation in progress"
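As a quick reference for the semantics of the component above, here is a standalone sketch of the same list operations on a plain Python list (illustrative only; it bypasses the component and canvas plumbing):

```python
# Standalone illustration of the ListOperations semantics shown in the diff.
items = ["b", "a", "b", "c"]

top_n = items[:2]                          # topN with n=2 -> ['b', 'a']
head = [items[2 - 1]]                      # head with n=2 -> ['a'] (1-based index)
tail = [items[-2]]                         # tail with n=2 -> ['b'] (2nd from the end)
filtered = [i for i in items if i == "b"]  # filter with operator '=' and value 'b'
sorted_asc = sorted(items)                 # sort with sort_method='asc'
deduped = list(dict.fromkeys(items))       # drop_duplicates, keeping first occurrence

print(top_n, head, tail, filtered, sorted_asc, deduped)
```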
@ -222,7 +222,7 @@ class LLM(ComponentBase):
|
||||
output_structure = self._param.outputs['structured']
|
||||
except Exception:
|
||||
pass
|
||||
if output_structure:
|
||||
if output_structure and isinstance(output_structure, dict) and output_structure.get("properties"):
|
||||
schema=json.dumps(output_structure, ensure_ascii=False, indent=2)
|
||||
prompt += structured_output_prompt(schema)
|
||||
for _ in range(self._param.max_retries+1):
|
||||
|
||||
@ -17,6 +17,9 @@ import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import pypandoc
|
||||
import logging
|
||||
import tempfile
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
|
||||
@ -24,7 +27,8 @@ from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from jinja2 import Template as Jinja2Template
|
||||
|
||||
from common.connection_utils import timeout
|
||||
|
||||
from common.misc_utils import get_uuid
|
||||
from common import settings
|
||||
|
||||
class MessageParam(ComponentParamBase):
|
||||
"""
|
||||
@ -34,6 +38,7 @@ class MessageParam(ComponentParamBase):
|
||||
super().__init__()
|
||||
self.content = []
|
||||
self.stream = True
|
||||
self.output_format = None # default output format
|
||||
self.outputs = {
|
||||
"content": {
|
||||
"type": "str"
|
||||
@ -133,6 +138,7 @@ class Message(ComponentBase):
|
||||
yield rand_cnt[s: ]
|
||||
|
||||
self.set_output("content", all_content)
|
||||
self._convert_content(all_content)
|
||||
|
||||
def _is_jinjia2(self, content:str) -> bool:
|
||||
patt = [
|
||||
@ -164,6 +170,68 @@ class Message(ComponentBase):
|
||||
content = re.sub(n, v, content)
|
||||
|
||||
self.set_output("content", content)
|
||||
self._convert_content(content)
|
||||
|
||||
def thoughts(self) -> str:
|
||||
return ""
|
||||
|
||||
def _convert_content(self, content):
|
||||
doc_id = get_uuid()
|
||||
|
||||
if self._param.output_format.lower() not in {"markdown", "html", "pdf", "docx"}:
|
||||
self._param.output_format = "markdown"
|
||||
|
||||
try:
|
||||
if self._param.output_format in {"markdown", "html"}:
|
||||
if isinstance(content, str):
|
||||
converted = pypandoc.convert_text(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
)
|
||||
else:
|
||||
converted = pypandoc.convert_file(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
)
|
||||
|
||||
binary_content = converted.encode("utf-8")
|
||||
|
||||
else: # pdf, docx
|
||||
with tempfile.NamedTemporaryFile(suffix=f".{self._param.output_format}", delete=False) as tmp:
|
||||
tmp_name = tmp.name
|
||||
|
||||
try:
|
||||
if isinstance(content, str):
|
||||
pypandoc.convert_text(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
outputfile=tmp_name,
|
||||
)
|
||||
else:
|
||||
pypandoc.convert_file(
|
||||
content,
|
||||
to=self._param.output_format,
|
||||
format="markdown",
|
||||
outputfile=tmp_name,
|
||||
)
|
||||
|
||||
with open(tmp_name, "rb") as f:
|
||||
binary_content = f.read()
|
||||
|
||||
finally:
|
||||
if os.path.exists(tmp_name):
|
||||
os.remove(tmp_name)
|
||||
|
||||
settings.STORAGE_IMPL.put(self._canvas._tenant_id, doc_id, binary_content)
|
||||
self.set_output("attachment", {
|
||||
"doc_id":doc_id,
|
||||
"format":self._param.output_format,
|
||||
"file_name":f"{doc_id[:8]}.{self._param.output_format}"})
|
||||
|
||||
logging.info(f"Converted content uploaded as {doc_id} (format={self._param.output_format})")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error converting content to {self._param.output_format}: {e}")
|
||||
@ -83,10 +83,10 @@
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"password": "",
|
||||
"port": 3306,
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
"username": ""
|
||||
}
|
||||
},
|
||||
"upstream": [
|
||||
@ -527,10 +527,10 @@
|
||||
"value": []
|
||||
}
|
||||
},
|
||||
"password": "20010812Yy!",
|
||||
"password": "",
|
||||
"port": 3306,
|
||||
"sql": "{Agent:WickedGoatsDivide@content}",
|
||||
"username": "13637682833@163.com"
|
||||
"username": ""
|
||||
},
|
||||
"label": "ExeSQL",
|
||||
"name": "ExeSQL"
|
||||
|
||||
@ -21,9 +21,8 @@ from functools import partial
|
||||
from typing import TypedDict, List, Any
|
||||
from agent.component.base import ComponentParamBase, ComponentBase
|
||||
from common.misc_utils import hash_str2int
|
||||
from rag.llm.chat_model import ToolCallSession
|
||||
from rag.prompts.generator import kb_prompt
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession
|
||||
from timeit import default_timer as timer
|
||||
|
||||
|
||||
|
||||
@ -96,12 +96,12 @@ login_manager.init_app(app)
|
||||
commands.register_commands(app)
|
||||
|
||||
|
||||
def search_pages_path(pages_dir):
|
||||
def search_pages_path(page_path):
|
||||
app_path_list = [
|
||||
path for path in pages_dir.glob("*_app.py") if not path.name.startswith(".")
|
||||
path for path in page_path.glob("*_app.py") if not path.name.startswith(".")
|
||||
]
|
||||
api_path_list = [
|
||||
path for path in pages_dir.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")
|
||||
]
|
||||
app_path_list.extend(api_path_list)
|
||||
return app_path_list
|
||||
@ -138,7 +138,7 @@ pages_dir = [
|
||||
]
|
||||
|
||||
client_urls_prefix = [
|
||||
register_page(path) for dir in pages_dir for path in search_pages_path(dir)
|
||||
register_page(path) for directory in pages_dir for path in search_pages_path(directory)
|
||||
]
|
||||
|
||||
|
||||
@ -177,5 +177,7 @@ def load_user(web_request):
|
||||
|
||||
|
||||
@app.teardown_request
|
||||
def _db_close(exc):
|
||||
def _db_close(exception):
|
||||
if exception:
|
||||
logging.exception(f"Request failed: {exception}")
|
||||
close_connection()
|
||||
|
||||
@ -13,41 +13,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from flask import request, Response
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db import VALID_FILE_TYPES, FileType
|
||||
from api.db.db_models import APIToken, Task, File
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import APITokenService, API4ConversationService
|
||||
from api.db.services.dialog_service import DialogService, chat
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import queue_tasks, TaskService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import RetCode, VALID_TASK_STATUS, LLMType, ParserType, FileSource
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, get_json_result, validate_request, \
|
||||
generate_confirmation_token
|
||||
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import keyword_extraction
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from agent.canvas import Canvas
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from common import settings
|
||||
|
||||
|
||||
@manager.route('/new_token', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@ -138,758 +113,3 @@ def stats():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/new_conversation', methods=['GET']) # noqa: F821
|
||||
def set_conversation():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
if objs[0].source == "agent":
|
||||
e, cvs = UserCanvasService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent"
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
else:
|
||||
e, dia = DialogService.get_by_id(objs[0].dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found")
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": dia.id,
|
||||
"user_id": request.args.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": dia.prompt_config["prologue"]}]
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id", "messages")
|
||||
def completion():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = False
|
||||
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
def rename_field(ans):
|
||||
reference = ans['reference']
|
||||
if not isinstance(reference, dict):
|
||||
return
|
||||
for chunk_i in reference.get('chunks', []):
|
||||
if 'docnm_kwd' in chunk_i:
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
stream = req.get("stream", True)
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "content": ""}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=stream)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
if stream:
|
||||
assert isinstance(answer, partial), "Nothing. Is it over?"
|
||||
|
||||
def sse():
|
||||
nonlocal answer, cvs, conv
|
||||
try:
|
||||
for ans in answer():
|
||||
for k in ans.keys():
|
||||
final_ans[k] = ans[k]
|
||||
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
canvas.history.append(("assistant", final_ans["content"]))
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
fillin_conv(result)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
rename_field(result)
|
||||
return get_json_result(data=result)
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
del req["messages"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
for ans in chat(dia, msg, True, **req):
|
||||
fillin_conv(ans)
|
||||
rename_field(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(stream(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
answer = None
|
||||
for ans in chat(dia, msg, **req):
|
||||
answer = ans
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
break
|
||||
rename_field(answer)
|
||||
return get_json_result(data=answer)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/conversation/<conversation_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_conversation(conversation_id):
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
try:
|
||||
e, conv = API4ConversationService.get_by_id(conversation_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
|
||||
conv = conv.to_dict()
|
||||
if token != APIToken.query(dialog_id=conv['dialog_id'])[0].token:
|
||||
return get_json_result(data=False, message='Authentication error: API key is invalid for this conversation_id!"',
|
||||
code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
for referenct_i in conv['reference']:
|
||||
if referenct_i is None or len(referenct_i) == 0:
|
||||
continue
|
||||
for chunk_i in referenct_i['chunks']:
|
||||
if 'docnm_kwd' in chunk_i.keys():
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
return get_json_result(data=conv)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/upload', methods=['POST']) # noqa: F821
|
||||
@validate_request("kb_name")
|
||||
def upload():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
kb_name = request.form.get("kb_name").strip()
|
||||
tenant_id = objs[0].tenant_id
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file = request.files['file']
|
||||
if file.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
kb_root_folder = FileService.get_kb_folder(tenant_id)
|
||||
kb_folder = FileService.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
try:
|
||||
if DocumentService.get_doc_count(kb.tenant_id) >= int(os.environ.get('MAX_FILE_NUM_PER_USER', 8192)):
|
||||
return get_data_error_result(
|
||||
message="Exceed the maximum file number of a free user!")
|
||||
|
||||
filename = duplicate_name(
|
||||
DocumentService.query,
|
||||
name=file.filename,
|
||||
kb_id=kb_id)
|
||||
filetype = filename_type(filename)
|
||||
if not filetype:
|
||||
return get_data_error_result(
|
||||
message="This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
while settings.STORAGE_IMPL.obj_exist(kb_id, location):
|
||||
location += "_"
|
||||
blob = request.files['file'].read()
|
||||
settings.STORAGE_IMPL.put(kb_id, location, blob)
|
||||
doc = {
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": kb.parser_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": kb.tenant_id,
|
||||
"type": filetype,
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob),
|
||||
"suffix": Path(filename).suffix.lstrip("."),
|
||||
}
|
||||
|
||||
form_data = request.form
|
||||
if "parser_id" in form_data.keys():
|
||||
if request.form.get("parser_id").strip() in list(vars(ParserType).values())[1:-3]:
|
||||
doc["parser_id"] = request.form.get("parser_id").strip()
|
||||
if doc["type"] == FileType.VISUAL:
|
||||
doc["parser_id"] = ParserType.PICTURE.value
|
||||
if doc["type"] == FileType.AURAL:
|
||||
doc["parser_id"] = ParserType.AUDIO.value
|
||||
if re.search(r"\.(ppt|pptx|pages)$", filename):
|
||||
doc["parser_id"] = ParserType.PRESENTATION.value
|
||||
if re.search(r"\.(eml)$", filename):
|
||||
doc["parser_id"] = ParserType.EMAIL.value
|
||||
|
||||
doc_result = DocumentService.insert(doc)
|
||||
FileService.add_file_from_kb(doc, kb_folder["id"], kb.tenant_id)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
if "run" in form_data.keys():
|
||||
if request.form.get("run").strip() == "1":
|
||||
try:
|
||||
info = {"run": 1, "progress": 0, "progress_msg": "", "chunk_num": 0, "token_num": 0}
|
||||
DocumentService.update_by_id(doc["id"], info)
|
||||
# if str(req["run"]) == TaskStatus.CANCEL.value:
|
||||
tenant_id = DocumentService.get_tenant_id(doc["id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
# e, doc = DocumentService.get_by_id(doc["id"])
|
||||
TaskService.filter_delete([Task.doc_id == doc["id"]])
|
||||
e, doc = DocumentService.get_by_id(doc["id"])
|
||||
doc = doc.to_dict()
|
||||
doc["tenant_id"] = tenant_id
|
||||
bucket, name = File2DocumentService.get_storage_address(doc_id=doc["id"])
|
||||
queue_tasks(doc, bucket, name, 0)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=doc_result.to_json())
|
||||
|
||||
|
||||
@manager.route('/document/upload_and_parse', methods=['POST']) # noqa: F821
|
||||
@validate_request("conversation_id")
|
||||
def upload_parse():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
if 'file' not in request.files:
|
||||
return get_json_result(
|
||||
data=False, message='No file part!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
file_objs = request.files.getlist('file')
|
||||
for file_obj in file_objs:
|
||||
if file_obj.filename == '':
|
||||
return get_json_result(
|
||||
data=False, message='No file selected!', code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
doc_ids = doc_upload_and_parse(request.form.get("conversation_id"), file_objs, objs[0].tenant_id)
|
||||
return get_json_result(data=doc_ids)
|
||||
|
||||
|
||||
@manager.route('/list_chunks', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_chunks():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
|
||||
try:
|
||||
if "doc_name" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id_by_name(req['doc_name'])
|
||||
doc_id = DocumentService.get_doc_id_by_doc_name(req['doc_name'])
|
||||
|
||||
elif "doc_id" in req.keys():
|
||||
tenant_id = DocumentService.get_tenant_id(req['doc_id'])
|
||||
doc_id = req['doc_id']
|
||||
else:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_name or doc_id"
|
||||
)
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
|
||||
res = settings.retriever.chunk_list(doc_id, tenant_id, kb_ids)
|
||||
res = [
|
||||
{
|
||||
"content": res_item["content_with_weight"],
|
||||
"doc_name": res_item["docnm_kwd"],
|
||||
"image_id": res_item["img_id"]
|
||||
} for res_item in res
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
return get_json_result(data=res)
|
||||
|
||||
@manager.route('/get_chunk/<chunk_id>', methods=['GET']) # noqa: F821
|
||||
# @login_required
|
||||
def get_chunk(chunk_id):
|
||||
from rag.nlp import search
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(tenant_id)
|
||||
chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), kb_ids)
|
||||
if chunk is None:
|
||||
return server_error_response(Exception("Chunk not found"))
|
||||
k = []
|
||||
for n in chunk.keys():
|
||||
if re.search(r"(_vec$|_sm_|_tks|_ltks)", n):
|
||||
k.append(n)
|
||||
for n in k:
|
||||
del chunk[n]
|
||||
|
||||
return get_json_result(data=chunk)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route('/list_kb_docs', methods=['POST']) # noqa: F821
|
||||
# @login_required
|
||||
def list_kb_docs():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
req = request.json
|
||||
tenant_id = objs[0].tenant_id
|
||||
kb_name = req.get("kb_name", "").strip()
|
||||
|
||||
try:
|
||||
e, kb = KnowledgebaseService.get_by_name(kb_name, tenant_id)
|
||||
if not e:
|
||||
return get_data_error_result(
|
||||
message="Can't find this knowledgebase!")
|
||||
kb_id = kb.id
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
page_number = int(req.get("page", 1))
|
||||
items_per_page = int(req.get("page_size", 15))
|
||||
orderby = req.get("orderby", "create_time")
|
||||
desc = req.get("desc", True)
|
||||
keywords = req.get("keywords", "")
|
||||
status = req.get("status", [])
|
||||
if status:
|
||||
invalid_status = {s for s in status if s not in VALID_TASK_STATUS}
|
||||
if invalid_status:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter status conditions: {', '.join(invalid_status)}"
|
||||
)
|
||||
types = req.get("types", [])
|
||||
if types:
|
||||
invalid_types = {t for t in types if t not in VALID_FILE_TYPES}
|
||||
if invalid_types:
|
||||
return get_data_error_result(
|
||||
message=f"Invalid filter conditions: {', '.join(invalid_types)} type{'s' if len(invalid_types) > 1 else ''}"
|
||||
)
|
||||
try:
|
||||
docs, tol = DocumentService.get_by_kb_id(
|
||||
kb_id, page_number, items_per_page, orderby, desc, keywords, status, types)
|
||||
docs = [{"doc_id": doc['id'], "doc_name": doc['name']} for doc in docs]
|
||||
|
||||
return get_json_result(data={"total": tol, "docs": docs})
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/document/infos', methods=['POST']) # noqa: F821
|
||||
@validate_request("doc_ids")
|
||||
def docinfos():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
req = request.json
|
||||
doc_ids = req["doc_ids"]
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
return get_json_result(data=list(docs.dicts()))
|
||||
|
||||
|
||||
@manager.route('/document', methods=['DELETE']) # noqa: F821
|
||||
# @login_required
|
||||
def document_rm():
|
||||
token = request.headers.get('Authorization').split()[1]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
tenant_id = objs[0].tenant_id
|
||||
req = request.json
|
||||
try:
|
||||
doc_ids = DocumentService.get_doc_ids_by_doc_names(req.get("doc_names", []))
|
||||
for doc_id in req.get("doc_ids", []):
|
||||
if doc_id not in doc_ids:
|
||||
doc_ids.append(doc_id)
|
||||
|
||||
if not doc_ids:
|
||||
return get_json_result(
|
||||
data=False, message="Can't find doc_names or doc_ids"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
root_folder = FileService.get_root_folder(tenant_id)
|
||||
pf_id = root_folder["id"]
|
||||
FileService.init_knowledgebase_docs(pf_id, tenant_id)
|
||||
|
||||
errors = ""
|
||||
docs = DocumentService.get_by_ids(doc_ids)
|
||||
doc_dic = {}
|
||||
for doc in docs:
|
||||
doc_dic[doc.id] = doc
|
||||
|
||||
for doc_id in doc_ids:
|
||||
try:
|
||||
if doc_id not in doc_dic:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
doc = doc_dic[doc_id]
|
||||
tenant_id = DocumentService.get_tenant_id(doc_id)
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
b, n = File2DocumentService.get_storage_address(doc_id=doc_id)
|
||||
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_data_error_result(
|
||||
message="Database error (Document removal)!")
|
||||
|
||||
f2d = File2DocumentService.get_by_document_id(doc_id)
|
||||
FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
|
||||
File2DocumentService.delete_by_document_id(doc_id)
|
||||
|
||||
settings.STORAGE_IMPL.rm(b, n)
|
||||
except Exception as e:
|
||||
errors += str(e)
|
||||
|
||||
if errors:
|
||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/completion_aibotk', methods=['POST']) # noqa: F821
|
||||
@validate_request("Authorization", "conversation_id", "word")
|
||||
def completion_faq():
|
||||
import base64
|
||||
req = request.json
|
||||
|
||||
token = req["Authorization"]
|
||||
objs = APIToken.query(token=token)
|
||||
if not objs:
|
||||
return get_json_result(
|
||||
data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
e, conv = API4ConversationService.get_by_id(req["conversation_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
if "quote" not in req:
|
||||
req["quote"] = True
|
||||
|
||||
msg = [{"role": "user", "content": req["word"]}]
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
message_id = msg[-1]["id"]
|
||||
|
||||
def fillin_conv(ans):
|
||||
nonlocal conv, message_id
|
||||
if not conv.reference:
|
||||
conv.reference.append(ans["reference"])
|
||||
else:
|
||||
conv.reference[-1] = ans["reference"]
|
||||
conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
|
||||
ans["id"] = message_id
|
||||
|
||||
try:
|
||||
if conv.source == "agent":
|
||||
conv.message.append(msg[-1])
|
||||
e, cvs = UserCanvasService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return server_error_response("canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
final_ans = {"reference": [], "doc_aggs": []}
|
||||
canvas = Canvas(cvs.dsl, objs[0].tenant_id)
|
||||
|
||||
canvas.messages.append(msg[-1])
|
||||
canvas.add_user_input(msg[-1]["content"])
|
||||
answer = canvas.run(stream=False)
|
||||
|
||||
assert answer is not None, "Nothing. Is it over?"
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
|
||||
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
|
||||
if final_ans.get("reference"):
|
||||
canvas.reference.append(final_ans["reference"])
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
|
||||
ans = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = settings.STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
# ******************For dialog******************
|
||||
conv.message.append(msg[-1])
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Dialog not found!")
|
||||
del req["conversation_id"]
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
data_type_picture = {
|
||||
"type": 3,
|
||||
"url": "base64 content"
|
||||
}
|
||||
data = [
|
||||
{
|
||||
"type": 1,
|
||||
"content": ""
|
||||
}
|
||||
]
|
||||
ans = ""
|
||||
for a in chat(dia, msg, stream=False, **req):
|
||||
ans = a
|
||||
break
|
||||
data[0]["content"] += re.sub(r'##\d\$\$', '', ans["answer"])
|
||||
fillin_conv(ans)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
|
||||
chunk_idxs = [int(match[2]) for match in re.findall(r'##\d\$\$', ans["answer"])]
|
||||
for chunk_idx in chunk_idxs[:1]:
|
||||
if ans["reference"]["chunks"][chunk_idx]["img_id"]:
|
||||
try:
|
||||
bkt, nm = ans["reference"]["chunks"][chunk_idx]["img_id"].split("-")
|
||||
response = settings.STORAGE_IMPL.get(bkt, nm)
|
||||
data_type_picture["url"] = base64.b64encode(response).decode('utf-8')
|
||||
data.append(data_type_picture)
|
||||
break
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
response = {"code": 200, "msg": "success", "data": data}
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/retrieval', methods=['POST'])  # noqa: F821
@validate_request("kb_id", "question")
def retrieval():
    token = request.headers.get('Authorization').split()[1]
    objs = APIToken.query(token=token)
    if not objs:
        return get_json_result(
            data=False, message='Authentication error: API key is invalid!"', code=RetCode.AUTHENTICATION_ERROR)

    req = request.json
    kb_ids = req.get("kb_id", [])
    doc_ids = req.get("doc_ids", [])
    question = req.get("question")
    page = int(req.get("page", 1))
    size = int(req.get("page_size", 30))
    similarity_threshold = float(req.get("similarity_threshold", 0.2))
    vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3))
    top = int(req.get("top_k", 1024))
    highlight = bool(req.get("highlight", False))

    try:
        kbs = KnowledgebaseService.get_by_ids(kb_ids)
        embd_nms = list(set([kb.embd_id for kb in kbs]))
        if len(embd_nms) != 1:
            return get_json_result(
                data=False, message='Knowledge bases use different embedding models or does not exist."',
                code=RetCode.AUTHENTICATION_ERROR)

        embd_mdl = LLMBundle(kbs[0].tenant_id, LLMType.EMBEDDING, llm_name=kbs[0].embd_id)
        rerank_mdl = None
        if req.get("rerank_id"):
            rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
        if req.get("keyword", False):
            chat_mdl = LLMBundle(kbs[0].tenant_id, LLMType.CHAT)
            question += keyword_extraction(chat_mdl, question)
        ranks = settings.retriever.retrieval(question, embd_mdl, kbs[0].tenant_id, kb_ids, page, size,
                                             similarity_threshold, vector_similarity_weight, top,
                                             doc_ids, rerank_mdl=rerank_mdl, highlight=highlight,
                                             rank_feature=label_question(question, kbs))
        for c in ranks["chunks"]:
            c.pop("vector", None)
        return get_json_result(data=ranks)
    except Exception as e:
        if str(e).find("not_found") > 0:
            return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
                                   code=RetCode.DATA_ERROR)
        return server_error_response(e)

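A minimal client-side sketch of calling this retrieval endpoint follows. The base URL, route prefix, and API key are assumptions (placeholders to adapt to your deployment); only parameters actually read by the handler above are sent.

```python
import requests

# Hypothetical values -- adjust the base URL, route prefix, and token to your deployment.
BASE_URL = "http://localhost:9380"
API_KEY = "ragflow-xxxxxx"

resp = requests.post(
    f"{BASE_URL}/v1/api/retrieval",  # prefix depends on how the blueprint is registered
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "kb_id": ["<dataset_id>"],
        "question": "What is RAGFlow?",
        "page": 1,
        "page_size": 30,
        "similarity_threshold": 0.2,
        "vector_similarity_weight": 0.3,
        "top_k": 1024,
        "highlight": False,
    },
    timeout=60,
)
print(resp.json())
```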
@ -426,7 +426,6 @@ def test_db_connect():
|
||||
try:
|
||||
import trino
|
||||
import os
|
||||
from trino.auth import BasicAuthentication
|
||||
except Exception as e:
|
||||
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
|
||||
|
||||
@ -438,7 +437,7 @@ def test_db_connect():
|
||||
|
||||
auth = None
|
||||
if http_scheme == "https" and req.get("password"):
|
||||
auth = BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
|
||||
conn = trino.dbapi.connect(
|
||||
host=req["host"],
|
||||
@ -471,8 +470,8 @@ def test_db_connect():
|
||||
@login_required
|
||||
def getlistversion(canvas_id):
|
||||
try:
|
||||
list =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=list)
|
||||
versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=versions)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Error getting history files: {e}")
|
||||
|
||||
|
||||
@ -55,7 +55,6 @@ def set_connector():
|
||||
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
|
||||
"status": TaskStatus.SCHEDULE,
|
||||
}
|
||||
conn["status"] = TaskStatus.SCHEDULE
|
||||
ConnectorService.save(**conn)
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
@ -85,7 +85,6 @@ def get():
|
||||
if not e:
|
||||
return get_data_error_result(message="Conversation not found!")
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
avatar = None
|
||||
for tenant in tenants:
|
||||
dialog = DialogService.query(tenant_id=tenant.tenant_id, id=conv.dialog_id)
|
||||
if dialog and len(dialog) > 0:
|
||||
|
||||
@ -154,15 +154,15 @@ def get_kb_names(kb_ids):
|
||||
@login_required
|
||||
def list_dialogs():
|
||||
try:
|
||||
diags = DialogService.query(
|
||||
conversations = DialogService.query(
|
||||
tenant_id=current_user.id,
|
||||
status=StatusEnum.VALID.value,
|
||||
reverse=True,
|
||||
order_by=DialogService.model.create_time)
|
||||
diags = [d.to_dict() for d in diags]
|
||||
for d in diags:
|
||||
d["kb_ids"], d["kb_names"] = get_kb_names(d["kb_ids"])
|
||||
return get_json_result(data=diags)
|
||||
conversations = [d.to_dict() for d in conversations]
|
||||
for conversation in conversations:
|
||||
conversation["kb_ids"], conversation["kb_names"] = get_kb_names(conversation["kb_ids"])
|
||||
return get_json_result(data=conversations)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@ -308,7 +308,7 @@ def get_filter():
|
||||
|
||||
@manager.route("/infos", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
def docinfos():
|
||||
def doc_infos():
|
||||
req = request.json
|
||||
doc_ids = req["doc_ids"]
|
||||
for doc_id in doc_ids:
|
||||
@ -508,6 +508,7 @@ def get(doc_id):
|
||||
ext = ext.group(1) if ext else None
|
||||
if ext:
|
||||
if doc.type == FileType.VISUAL.value:
|
||||
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"image/{ext}")
|
||||
else:
|
||||
content_type = CONTENT_TYPE_MAP.get(ext, f"application/{ext}")
|
||||
@ -517,6 +518,22 @@ def get(doc_id):
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/download/<attachment_id>", methods=["GET"])  # noqa: F821
@login_required
def download_attachment(attachment_id):
    try:
        ext = request.args.get("ext", "markdown")
        data = settings.STORAGE_IMPL.get(current_user.id, attachment_id)
        # data = settings.STORAGE_IMPL.get("eb500d50bb0411f0907561d2782adda5", attachment_id)
        response = flask.make_response(data)
        response.headers.set("Content-Type", CONTENT_TYPE_MAP.get(ext, f"application/{ext}"))

        return response

    except Exception as e:
        return server_error_response(e)


@manager.route("/change_parser", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id")
|
||||
@ -544,6 +561,7 @@ def change_parser():
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
return None
|
||||
|
||||
try:
|
||||
if "pipeline_id" in req and req["pipeline_id"] != "":
|
||||
|
||||
@ -246,8 +246,8 @@ def rm():
|
||||
try:
|
||||
if file.location:
|
||||
settings.STORAGE_IMPL.rm(file.parent_id, file.location)
|
||||
except Exception:
|
||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}")
|
||||
except Exception as e:
|
||||
logging.exception(f"Fail to remove object: {file.parent_id}/{file.location}, error: {e}")
|
||||
|
||||
informs = File2DocumentService.get_by_file_id(file.id)
|
||||
for inform in informs:
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
@ -54,6 +55,10 @@ def create():
|
||||
**req
|
||||
)
|
||||
|
||||
code = req.get("code")
|
||||
if code:
|
||||
return get_data_error_result(code=code, message=req.get("message"))
|
||||
|
||||
try:
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_data_error_result()
|
||||
@ -731,6 +736,8 @@ def delete_kb_task():
|
||||
def cancel_task(task_id):
|
||||
REDIS_CONN.set(f"{task_id}-cancel", "x")
|
||||
|
||||
kb_task_id_field: str = ""
|
||||
kb_task_finish_at: str = ""
|
||||
match pipeline_task_type:
|
||||
case PipelineTaskType.GRAPH_RAG:
|
||||
kb_task_id_field = "graphrag_task_id"
|
||||
@ -807,7 +814,7 @@ def check_embedding():
|
||||
offset=0, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
total = docStoreConn.getTotal(res0)
|
||||
total = docStoreConn.get_total(res0)
|
||||
if total <= 0:
|
||||
return []
|
||||
|
||||
@ -824,7 +831,7 @@ def check_embedding():
|
||||
offset=off, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
)
|
||||
ids = docStoreConn.getChunkIds(res1)
|
||||
ids = docStoreConn.get_chunk_ids(res1)
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
@ -845,8 +852,13 @@ def check_embedding():
|
||||
"position_int": full_doc.get("position_int"),
|
||||
"top_int": full_doc.get("top_int"),
|
||||
"content_with_weight": full_doc.get("content_with_weight") or "",
|
||||
"question_kwd": full_doc.get("question_kwd") or []
|
||||
})
|
||||
return out
|
||||
|
||||
def _clean(s: str) -> str:
|
||||
s = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", s or "")
|
||||
return s if s else "None"
|
||||
req = request.json
|
||||
kb_id = req.get("kb_id", "")
|
||||
embd_id = req.get("embd_id", "")
|
||||
@ -859,8 +871,10 @@ def check_embedding():
|
||||
|
||||
results, eff_sims = [], []
|
||||
for ck in samples:
|
||||
txt = (ck.get("content_with_weight") or "").strip()
|
||||
if not txt:
|
||||
title = ck.get("doc_name") or "Title"
|
||||
txt_in = "\n".join(ck.get("question_kwd") or []) or ck.get("content_with_weight") or ""
|
||||
txt_in = _clean(txt_in)
|
||||
if not txt_in:
|
||||
results.append({"chunk_id": ck["chunk_id"], "reason": "no_text"})
|
||||
continue
|
||||
|
||||
@ -869,8 +883,16 @@ def check_embedding():
|
||||
continue
|
||||
|
||||
try:
|
||||
qv, _ = emb_mdl.encode_queries(txt)
|
||||
sim = _cos_sim(qv, ck["vector"])
|
||||
v, _ = emb_mdl.encode([title, txt_in])
|
||||
sim_content = _cos_sim(v[1], ck["vector"])
|
||||
title_w = 0.1
|
||||
qv_mix = title_w * v[0] + (1 - title_w) * v[1]
|
||||
sim_mix = _cos_sim(qv_mix, ck["vector"])
|
||||
sim = sim_content
|
||||
mode = "content_only"
|
||||
if sim_mix > sim:
|
||||
sim = sim_mix
|
||||
mode = "title+content"
|
||||
except Exception:
|
||||
return get_error_data_result(message="embedding failure")
|
||||
|
||||
@ -892,9 +914,10 @@ def check_embedding():
|
||||
"avg_cos_sim": round(float(np.mean(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"min_cos_sim": round(float(np.min(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"max_cos_sim": round(float(np.max(eff_sims)) if eff_sims else 0.0, 6),
|
||||
"match_mode": mode,
|
||||
}
|
||||
if summary["avg_cos_sim"] > 0.99:
|
||||
if summary["avg_cos_sim"] > 0.9:
|
||||
return get_json_result(data={"summary": summary, "results": results})
|
||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
|
||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="Embedding model switch failed: the average similarity between old and new vectors is below 0.9, indicating incompatible vector spaces.", data={"summary": summary, "results": results})
|
||||
|
||||
|
||||
|
||||
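The hunk above relies on a cosine-similarity helper (`_cos_sim`) defined elsewhere in the same endpoint. Below is a minimal sketch, assuming numpy vectors, of how such a helper and the title/content mixing could look; it is illustrative, not the repository's exact implementation, and the vectors are random stand-ins.

```python
import numpy as np

def _cos_sim_sketch(a, b) -> float:
    # Cosine similarity between two vectors; 0.0 when either vector has zero norm.
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(a @ b / denom) if denom else 0.0

# Mixing logic mirroring the hunk: compare the stored chunk vector against both the
# content-only embedding and a title-weighted blend, then keep the better score.
title_w = 0.1
v_title, v_content = np.random.rand(8), np.random.rand(8)  # stand-ins for emb_mdl.encode([title, txt_in])
stored = np.random.rand(8)                                  # stand-in for ck["vector"]
sim = _cos_sim_sketch(v_content, stored)
qv_mix = title_w * v_title + (1 - title_w) * v_content
sim = max(sim, _cos_sim_sketch(qv_mix, stored))
```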
@ -25,7 +25,7 @@ from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
get_mcp_tools
|
||||
from api.utils.web_utils import get_float, safe_json_parse
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
|
||||
@manager.route("/list", methods=["POST"]) # noqa: F821
|
||||
|
||||
@ -41,12 +41,12 @@ def list_agents(tenant_id):
|
||||
return get_error_data_result("The agent doesn't exist.")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
orderby = request.args.get("orderby", "update_time")
|
||||
order_by = request.args.get("orderby", "update_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, title)
|
||||
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
|
||||
return get_result(data=canvas)
|
||||
|
||||
|
||||
|
||||
@ -21,10 +21,11 @@ import json
|
||||
from flask import request
|
||||
from peewee import OperationalError
|
||||
from api.db.db_models import File
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.document_service import DocumentService, queue_raptor_o_graphrag_tasks
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import GRAPH_RAPTOR_FAKE_DOC_ID, TaskService
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.constants import RetCode, FileSource, StatusEnum
|
||||
from api.utils.api_utils import (
|
||||
@ -118,7 +119,6 @@ def create(tenant_id):
|
||||
req, err = validate_and_parse_json_request(request, CreateDatasetReq)
|
||||
if err is not None:
|
||||
return get_error_argument_result(err)
|
||||
|
||||
req = KnowledgebaseService.create_with_name(
|
||||
name = req.pop("name", None),
|
||||
tenant_id = tenant_id,
|
||||
@ -144,7 +144,6 @@ def create(tenant_id):
|
||||
ok, k = KnowledgebaseService.get_by_id(req["id"])
|
||||
if not ok:
|
||||
return get_error_data_result(message="Dataset created failed")
|
||||
|
||||
response_data = remap_dictionary_keys(k.to_dict())
|
||||
return get_result(data=response_data)
|
||||
except Exception as e:
|
||||
@ -532,3 +531,157 @@ def delete_knowledge_graph(tenant_id, dataset_id):
|
||||
search.index_name(kb.tenant_id), dataset_id)
|
||||
|
||||
return get_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/run_graphrag", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def run_graphrag(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid GraphRAG task id is expected for Dataset {dataset_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A Graph Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=dataset_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="graphrag", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"graphrag_task_id": task_id}):
|
||||
logging.warning(f"Cannot save graphrag_task_id for Dataset {dataset_id}")
|
||||
|
||||
return get_result(data={"graphrag_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/trace_graphrag", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def trace_graphrag(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.graphrag_task_id
|
||||
if not task_id:
|
||||
return get_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_result(data={})
|
||||
|
||||
return get_result(data=task.to_dict())
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/run_raptor", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def run_raptor(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if task_id:
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
logging.warning(f"A valid RAPTOR task id is expected for Dataset {dataset_id}")
|
||||
|
||||
if task and task.progress not in [-1, 1]:
|
||||
return get_error_data_result(message=f"Task {task_id} in progress with status {task.progress}. A RAPTOR Task is already running.")
|
||||
|
||||
documents, _ = DocumentService.get_by_kb_id(
|
||||
kb_id=dataset_id,
|
||||
page_number=0,
|
||||
items_per_page=0,
|
||||
orderby="create_time",
|
||||
desc=False,
|
||||
keywords="",
|
||||
run_status=[],
|
||||
types=[],
|
||||
suffix=[],
|
||||
)
|
||||
if not documents:
|
||||
return get_error_data_result(message=f"No documents in Dataset {dataset_id}")
|
||||
|
||||
sample_document = documents[0]
|
||||
document_ids = [document["id"] for document in documents]
|
||||
|
||||
task_id = queue_raptor_o_graphrag_tasks(sample_doc_id=sample_document, ty="raptor", priority=0, fake_doc_id=GRAPH_RAPTOR_FAKE_DOC_ID, doc_ids=list(document_ids))
|
||||
|
||||
if not KnowledgebaseService.update_by_id(kb.id, {"raptor_task_id": task_id}):
|
||||
logging.warning(f"Cannot save raptor_task_id for Dataset {dataset_id}")
|
||||
|
||||
return get_result(data={"raptor_task_id": task_id})
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>/trace_raptor", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def trace_raptor(tenant_id,dataset_id):
|
||||
if not dataset_id:
|
||||
return get_error_data_result(message='Lack of "Dataset ID"')
|
||||
|
||||
if not KnowledgebaseService.accessible(dataset_id, tenant_id):
|
||||
return get_result(
|
||||
data=False,
|
||||
message='No authorization.',
|
||||
code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
ok, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="Invalid Dataset ID")
|
||||
|
||||
task_id = kb.raptor_task_id
|
||||
if not task_id:
|
||||
return get_result(data={})
|
||||
|
||||
ok, task = TaskService.get_by_id(task_id)
|
||||
if not ok:
|
||||
return get_error_data_result(message="RAPTOR Task Not Found or Error Occurred")
|
||||
|
||||
return get_result(data=task.to_dict())
|
||||
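A short client-side sketch of driving the new GraphRAG/RAPTOR dataset endpoints is shown below. It assumes these SDK routes are mounted under `/api/v1` like the other dataset endpoints; the base URL, API key, and dataset ID are placeholders.

```python
import requests

# Hypothetical values; adjust to your deployment.
BASE_URL = "http://localhost:9380"
API_KEY = "ragflow-xxxxxx"
DATASET_ID = "<dataset_id>"
headers = {"Authorization": f"Bearer {API_KEY}"}

# Kick off GraphRAG for the dataset (use run_raptor for the RAPTOR variant).
run = requests.post(f"{BASE_URL}/api/v1/datasets/{DATASET_ID}/run_graphrag", headers=headers, timeout=60)
print(run.json())  # expected to contain {"data": {"graphrag_task_id": "..."}} on success

# Then poll the matching trace endpoint for the task's progress record.
trace = requests.get(f"{BASE_URL}/api/v1/datasets/{DATASET_ID}/trace_graphrag", headers=headers, timeout=60)
print(trace.json())
```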
@ -93,6 +93,10 @@ def upload(dataset_id, tenant_id):
|
||||
type: file
|
||||
required: true
|
||||
description: Document files to upload.
|
||||
- in: formData
|
||||
name: parent_path
|
||||
type: string
|
||||
description: Optional nested path under the parent folder. Uses '/' separators.
|
||||
responses:
|
||||
200:
|
||||
description: Successfully uploaded documents.
|
||||
@ -151,7 +155,7 @@ def upload(dataset_id, tenant_id):
|
||||
e, kb = KnowledgebaseService.get_by_id(dataset_id)
|
||||
if not e:
|
||||
raise LookupError(f"Can't find the dataset with ID {dataset_id}!")
|
||||
err, files = FileService.upload_document(kb, file_objs, tenant_id)
|
||||
err, files = FileService.upload_document(kb, file_objs, tenant_id, parent_path=request.form.get("parent_path"))
|
||||
if err:
|
||||
return get_result(message="\n".join(err), code=RetCode.SERVER_ERROR)
|
||||
# rename key's name
|
||||
|
||||
@ -305,6 +305,7 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||
time.sleep(self.retry_delay * (2 ** attempt))
|
||||
else:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
class RetryingPooledPostgresqlDatabase(PooledPostgresqlDatabase):
|
||||
@ -772,7 +773,7 @@ class Document(DataBaseModel):
|
||||
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipleline ID", index=True)
|
||||
pipeline_id = CharField(max_length=32, null=True, help_text="pipeline ID", index=True)
|
||||
parser_config = JSONField(null=False, default={"pages": [[1, 1000000]]})
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document come from", index=True)
|
||||
type = CharField(max_length=32, null=False, help_text="file extension", index=True)
|
||||
@ -876,7 +877,7 @@ class Dialog(DataBaseModel):
|
||||
class Conversation(DataBaseModel):
|
||||
id = CharField(max_length=32, primary_key=True)
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="converastion name", index=True)
|
||||
name = CharField(max_length=255, null=True, help_text="conversation name", index=True)
|
||||
message = JSONField(null=True)
|
||||
reference = JSONField(null=True, default=[])
|
||||
user_id = CharField(max_length=255, null=True, help_text="user_id", index=True)
|
||||
|
||||
@ -70,7 +70,7 @@ class ConnectorService(CommonService):
|
||||
def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str):
|
||||
e, conn = cls.get_by_id(connector_id)
|
||||
if not e:
|
||||
return
|
||||
return None
|
||||
SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id])
|
||||
docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id)
|
||||
err = FileService.delete_docs([d.id for d in docs], tenant_id)
|
||||
@ -125,11 +125,11 @@ class SyncLogsService(CommonService):
|
||||
)
|
||||
|
||||
query = query.distinct().order_by(cls.model.update_time.desc())
|
||||
totbal = query.count()
|
||||
total = query.count()
|
||||
if page_number:
|
||||
query = query.paginate(page_number, items_per_page)
|
||||
|
||||
return list(query.dicts()), totbal
|
||||
return list(query.dicts()), total
|
||||
|
||||
@classmethod
|
||||
def start(cls, id, connector_id):
|
||||
@ -242,7 +242,7 @@ class Connector2KbService(CommonService):
|
||||
"id": get_uuid(),
|
||||
"connector_id": conn_id,
|
||||
"kb_id": kb_id,
|
||||
"auto_parse": conn.get("auto_parse", "1")
|
||||
"auto_parse": conn.get("auto_parse", "1")
|
||||
})
|
||||
SyncLogsService.schedule(conn_id, kb_id, reindex=True)
|
||||
|
||||
|
||||
@ -342,7 +342,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||
for ans in chat_solo(dialog, messages, stream):
|
||||
yield ans
|
||||
return
|
||||
return None
|
||||
|
||||
chat_start_ts = timer()
|
||||
|
||||
@ -386,7 +386,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
if ans:
|
||||
yield ans
|
||||
return
|
||||
return None
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge":
|
||||
@ -617,6 +617,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
res["audio_binary"] = tts(tts_mdl, answer)
|
||||
yield res
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
sys_prompt = """
|
||||
@ -745,7 +747,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
|
||||
def tts(tts_mdl, text):
|
||||
if not tts_mdl or not text:
|
||||
return
|
||||
return None
|
||||
bin = b""
|
||||
for chunk in tts_mdl.tts(text):
|
||||
bin += chunk
|
||||
|
||||
@ -113,7 +113,7 @@ class DocumentService(CommonService):
|
||||
def check_doc_health(cls, tenant_id: str, filename):
|
||||
import os
|
||||
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
|
||||
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
|
||||
if 0 < MAX_FILE_NUM_PER_USER <= DocumentService.get_doc_count(tenant_id):
|
||||
raise RuntimeError("Exceed the maximum file number of a free user!")
|
||||
if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
|
||||
raise RuntimeError("Exceed the maximum length of file name!")
|
||||
@ -309,7 +309,7 @@ class DocumentService(CommonService):
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.getChunkIds(chunks)
|
||||
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
@ -322,7 +322,7 @@ class DocumentService(CommonService):
|
||||
settings.STORAGE_IMPL.rm(doc.kb_id, doc.thumbnail)
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
graph_source = settings.docStoreConn.getFields(
|
||||
graph_source = settings.docStoreConn.get_fields(
|
||||
settings.docStoreConn.search(["source_id"], [], {"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]}, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [doc.kb_id]), ["source_id"]
|
||||
)
|
||||
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
|
||||
@ -464,7 +464,7 @@ class DocumentService(CommonService):
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@ -473,7 +473,7 @@ class DocumentService(CommonService):
|
||||
docs = cls.model.select(cls.model.kb_id).where(cls.model.id == doc_id)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["kb_id"]
|
||||
|
||||
@classmethod
|
||||
@ -486,7 +486,7 @@ class DocumentService(CommonService):
|
||||
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["tenant_id"]
|
||||
|
||||
@classmethod
|
||||
@ -533,7 +533,7 @@ class DocumentService(CommonService):
|
||||
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
|
||||
docs = docs.dicts()
|
||||
if not docs:
|
||||
return
|
||||
return None
|
||||
return docs[0]["embd_id"]
|
||||
|
||||
@classmethod
|
||||
@ -569,7 +569,7 @@ class DocumentService(CommonService):
|
||||
.where(cls.model.name == doc_name)
|
||||
doc_id = doc_id.dicts()
|
||||
if not doc_id:
|
||||
return
|
||||
return None
|
||||
return doc_id[0]["id"]
|
||||
|
||||
@classmethod
|
||||
@ -715,7 +715,7 @@ class DocumentService(CommonService):
|
||||
prg = 1
|
||||
status = TaskStatus.DONE.value
|
||||
|
||||
# only for special task and parsed docs and unfinised
|
||||
# only for special task and parsed docs and unfinished
|
||||
freeze_progress = special_task_running and doc_progress >= 1 and not finished
|
||||
msg = "\n".join(sorted(msg))
|
||||
info = {
|
||||
@ -974,13 +974,13 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
|
||||
def embedding(doc_id, cnts, batch_size=16):
|
||||
nonlocal embd_mdl, chunk_counts, token_counts
|
||||
vects = []
|
||||
vectors = []
|
||||
for i in range(0, len(cnts), batch_size):
|
||||
vts, c = embd_mdl.encode(cnts[i: i + batch_size])
|
||||
vects.extend(vts.tolist())
|
||||
vectors.extend(vts.tolist())
|
||||
chunk_counts[doc_id] += len(cnts[i:i + batch_size])
|
||||
token_counts[doc_id] += c
|
||||
return vects
|
||||
return vectors
|
||||
|
||||
idxnm = search.index_name(kb.tenant_id)
|
||||
try_create_idx = True
|
||||
@ -1011,15 +1011,15 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
except Exception:
|
||||
logging.exception("Mind map generation error")
|
||||
|
||||
vects = embedding(doc_id, [c["content_with_weight"] for c in cks])
|
||||
assert len(cks) == len(vects)
|
||||
vectors = embedding(doc_id, [c["content_with_weight"] for c in cks])
|
||||
assert len(cks) == len(vectors)
|
||||
for i, d in enumerate(cks):
|
||||
v = vects[i]
|
||||
v = vectors[i]
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vects[0]))
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ from common.misc_utils import get_uuid
|
||||
from common.constants import TaskStatus, FileSource, ParserType
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.task_service import TaskService
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
|
||||
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img, sanitize_path
|
||||
from rag.llm.cv_model import GptV4
|
||||
from common import settings
|
||||
|
||||
@ -329,7 +329,7 @@ class FileService(CommonService):
|
||||
current_id = start_id
|
||||
while current_id:
|
||||
e, file = cls.get_by_id(current_id)
|
||||
if file.parent_id != file.id and e:
|
||||
if e and file.parent_id != file.id:
|
||||
parent_folders.append(file)
|
||||
current_id = file.parent_id
|
||||
else:
|
||||
@ -423,13 +423,15 @@ class FileService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def upload_document(self, kb, file_objs, user_id, src="local"):
|
||||
def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str | None = None):
|
||||
root_folder = self.get_root_folder(user_id)
|
||||
pf_id = root_folder["id"]
|
||||
self.init_knowledgebase_docs(pf_id, user_id)
|
||||
kb_root_folder = self.get_kb_folder(user_id)
|
||||
kb_folder = self.new_a_file_from_kb(kb.tenant_id, kb.name, kb_root_folder["id"])
|
||||
|
||||
safe_parent_path = sanitize_path(parent_path)
|
||||
|
||||
err, files = [], []
|
||||
for file in file_objs:
|
||||
try:
|
||||
@ -439,7 +441,7 @@ class FileService(CommonService):
|
||||
if filetype == FileType.OTHER.value:
|
||||
raise RuntimeError("This type of file has not been supported yet!")
|
||||
|
||||
location = filename
|
||||
location = filename if not safe_parent_path else f"{safe_parent_path}/{filename}"
|
||||
while settings.STORAGE_IMPL.obj_exist(kb.id, location):
|
||||
location += "_"
|
||||
|
||||
|
||||
@ -24,9 +24,9 @@ from common.time_utils import current_timestamp, datetime_format
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.user_service import TenantService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.constants import StatusEnum
|
||||
from common.constants import StatusEnum, RetCode
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from api.utils.api_utils import get_parser_config, get_data_error_result
|
||||
from api.utils.api_utils import get_parser_config
|
||||
|
||||
class KnowledgebaseService(CommonService):
|
||||
"""Service class for managing knowledge base operations.
|
||||
@ -391,12 +391,12 @@ class KnowledgebaseService(CommonService):
|
||||
"""
|
||||
# Validate name
|
||||
if not isinstance(name, str):
|
||||
return get_data_error_result(message="Dataset name must be string.")
|
||||
return {"code": RetCode.DATA_ERROR, "message": "Dataset name must be string."}
|
||||
dataset_name = name.strip()
|
||||
if dataset_name == "":
|
||||
return get_data_error_result(message="Dataset name can't be empty.")
|
||||
if len(dataset_name) == 0:
|
||||
return {"code": RetCode.DATA_ERROR, "message": "Dataset name can't be empty."}
|
||||
if len(dataset_name.encode("utf-8")) > DATASET_NAME_LIMIT:
|
||||
return get_data_error_result(message=f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}")
|
||||
return {"code": RetCode.DATA_ERROR, "message": f"Dataset name length is {len(dataset_name)} which is larger than {DATASET_NAME_LIMIT}"}
|
||||
|
||||
# Deduplicate name within tenant
|
||||
dataset_name = duplicate_name(
|
||||
@ -409,7 +409,7 @@ class KnowledgebaseService(CommonService):
|
||||
# Verify tenant exists
|
||||
ok, _t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
return False, "Tenant not found."
|
||||
return {"code": RetCode.DATA_ERROR, "message": "Tenant does not exist."}
|
||||
|
||||
# Build payload
|
||||
kb_id = get_uuid()
|
||||
@ -419,11 +419,12 @@ class KnowledgebaseService(CommonService):
|
||||
"tenant_id": tenant_id,
|
||||
"created_by": tenant_id,
|
||||
"parser_id": (parser_id or "naive"),
|
||||
**kwargs
|
||||
**kwargs # Includes optional fields such as description, language, permission, avatar, parser_config, etc.
|
||||
}
|
||||
|
||||
# Default parser_config (align with kb_app.create) — do not accept external overrides
|
||||
# Update parser_config (always override with validated default/merged config)
|
||||
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ import re
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
from common.constants import LLMType
|
||||
from api.db.db_models import LLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||
@ -32,6 +33,14 @@ def get_init_tenant_llm(user_id):
|
||||
from common import settings
|
||||
tenant_llm = []
|
||||
|
||||
model_configs = {
|
||||
LLMType.CHAT: settings.CHAT_CFG,
|
||||
LLMType.EMBEDDING: settings.EMBEDDING_CFG,
|
||||
LLMType.SPEECH2TEXT: settings.ASR_CFG,
|
||||
LLMType.IMAGE2TEXT: settings.IMAGE2TEXT_CFG,
|
||||
LLMType.RERANK: settings.RERANK_CFG,
|
||||
}
|
||||
|
||||
seen = set()
|
||||
factory_configs = []
|
||||
for factory_config in [
|
||||
@ -54,8 +63,8 @@ def get_init_tenant_llm(user_id):
|
||||
"llm_factory": factory_config["factory"],
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": factory_config["api_key"],
|
||||
"api_base": factory_config["base_url"],
|
||||
"api_key": model_configs.get(llm.model_type, {}).get("api_key", factory_config["api_key"]),
|
||||
"api_base": model_configs.get(llm.model_type, {}).get("base_url", factory_config["base_url"]),
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||
}
|
||||
)
|
||||
@ -80,8 +89,8 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
def encode(self, texts: list):
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
|
||||
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="encode", model=self.llm_name, input={"texts": texts})
|
||||
|
||||
safe_texts = []
|
||||
for text in texts:
|
||||
token_size = num_tokens_from_string(text)
|
||||
@ -90,7 +99,7 @@ class LLMBundle(LLM4Tenant):
|
||||
safe_texts.append(text[:target_len])
|
||||
else:
|
||||
safe_texts.append(text)
|
||||
|
||||
|
||||
embeddings, used_tokens = self.mdl.encode(safe_texts)
|
||||
|
||||
llm_name = getattr(self, "llm_name", None)
|
||||
|
||||
@ -41,7 +41,7 @@ from api.db.db_models import init_database_tables as init_web_db
|
||||
from api.db.init_data import init_web_data
|
||||
from common.versions import get_ragflow_version
|
||||
from common.config_utils import show_configs
|
||||
from rag.utils.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from common.mcp_tool_call_conn import shutdown_all_mcp_sessions
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
@ -37,7 +37,7 @@ from peewee import OperationalError
|
||||
from common.constants import ActiveEnum
|
||||
from api.db.db_models import APIToken
|
||||
from api.utils.json_encode import CustomJSONEncoder
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from common.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService
|
||||
from common.connection_utils import timeout
|
||||
from common.constants import RetCode
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""
|
||||
Reusable HTML email templates and registry.
|
||||
"""
|
||||
|
||||
@ -164,3 +164,23 @@ def read_potential_broken_pdf(blob):
        return repaired

    return blob


def sanitize_path(raw_path: str | None) -> str:
    """Normalize and sanitize a user-provided path segment.

    - Converts backslashes to forward slashes
    - Strips leading/trailing slashes
    - Removes '.' and '..' segments
    - Restricts characters to A-Za-z0-9, underscore, dash, and '/'
    """
    if not raw_path:
        return ""
    backslash_re = re.compile(r"[\\]+")
    unsafe_re = re.compile(r"[^A-Za-z0-9_\-/]")
    normalized = backslash_re.sub("/", raw_path)
    normalized = normalized.strip("/")
    parts = [seg for seg in normalized.split("/") if seg and seg not in (".", "..")]
    sanitized = "/".join(parts)
    sanitized = unsafe_re.sub("", sanitized)
    return sanitized

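A few illustrative calls, consistent with the rules in the docstring above (the inputs are chosen here purely for demonstration):

```python
# Traversal segments and leading/trailing slashes are dropped, backslashes become '/':
assert sanitize_path("..\\..\\etc/passwd") == "etc/passwd"
# Characters outside A-Za-z0-9, '_', '-', '/' (such as spaces) are stripped:
assert sanitize_path("reports/2024/Q1 files") == "reports/2024/Q1files"
# '.' segments are removed; empty or missing input maps to an empty string:
assert sanitize_path("/docs/./guides/") == "docs/guides"
assert sanitize_path(None) == ""
```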
@ -173,7 +173,8 @@ def check_task_executor_alive():
|
||||
heartbeats = [json.loads(heartbeat) for heartbeat in heartbeats]
|
||||
task_executor_heartbeats[task_executor_id] = heartbeats
|
||||
if task_executor_heartbeats:
|
||||
return {"status": "alive", "message": task_executor_heartbeats}
|
||||
status = "alive" if any(task_executor_heartbeats.values()) else "timeout"
|
||||
return {"status": status, "message": task_executor_heartbeats}
|
||||
else:
|
||||
return {"status": "timeout", "message": "Not found any task executor."}
|
||||
except Exception as e:
|
||||
|
||||
@ -1,3 +1,19 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
check_comment_ascii.py (new file, 48 lines)
@ -0,0 +1,48 @@
#!/usr/bin/env python3

"""
Check whether given python files contain non-ASCII comments.

How to check the whole git repo:

```
$ git ls-files -z -- '*.py' | xargs -0 python3 check_comment_ascii.py
```
"""

import sys
import tokenize
import ast
import pathlib
import re

ASCII = re.compile(r"^[\n -~]*\Z")  # Printable ASCII + newline


def check(src: str, name: str) -> int:
    """
    docstring line 1
    docstring line 2
    """
    ok = 1
    # A common comment begins with `#`
    with tokenize.open(src) as fp:
        for tk in tokenize.generate_tokens(fp.readline):
            if tk.type == tokenize.COMMENT and not ASCII.fullmatch(tk.string):
                print(f"{name}:{tk.start[0]}: non-ASCII comment: {tk.string}")
                ok = 0
    # A docstring begins and ends with `'''`
    for node in ast.walk(ast.parse(pathlib.Path(src).read_text(), filename=name)):
        if isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.Module)):
            if (doc := ast.get_docstring(node)) and not ASCII.fullmatch(doc):
                print(f"{name}:{node.lineno}: non-ASCII docstring: {doc}")
                ok = 0
    return ok


if __name__ == "__main__":
    status = 0
    for file in sys.argv[1:]:
        if not check(file, file):
            status = 1
    sys.exit(status)

@ -11,7 +11,7 @@ from .confluence_connector import ConfluenceConnector
|
||||
from .discord_connector import DiscordConnector
|
||||
from .dropbox_connector import DropboxConnector
|
||||
from .google_drive.connector import GoogleDriveConnector
|
||||
from .jira_connector import JiraConnector
|
||||
from .jira.connector import JiraConnector
|
||||
from .sharepoint_connector import SharePointConnector
|
||||
from .teams_connector import TeamsConnector
|
||||
from .config import BlobType, DocumentSource
|
||||
|
||||
@ -87,6 +87,13 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure")
|
||||
|
||||
elif self.bucket_type == BlobType.S3_COMPATIBLE:
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["endpoint_url", "aws_access_key_id", "aws_secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("S3 Compatible Storage")
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {self.bucket_type}")
|
||||
|
||||
|
||||
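For the new S3-compatible branch above, the credentials dictionary is expected to carry exactly the three keys checked by the validation. The values below are placeholders for illustration only; the endpoint and keys depend on your S3-compatible provider (for example MinIO).

```python
# Placeholder values -- substitute your provider's endpoint and key pair.
s3_compatible_credentials = {
    "endpoint_url": "https://minio.example.com:9000",
    "aws_access_key_id": "<access-key-id>",
    "aws_secret_access_key": "<secret-access-key>",
}
```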
@ -13,6 +13,7 @@ def get_current_tz_offset() -> int:
|
||||
return round(time_diff.total_seconds() / 3600)
|
||||
|
||||
|
||||
ONE_MINUTE = 60
|
||||
ONE_HOUR = 3600
|
||||
ONE_DAY = ONE_HOUR * 24
|
||||
|
||||
@ -31,6 +32,7 @@ class BlobType(str, Enum):
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
S3_COMPATIBLE = "s3_compatible"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
@ -42,9 +44,11 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
SLACK = "slack"
|
||||
CONFLUENCE = "confluence"
|
||||
JIRA = "jira"
|
||||
GOOGLE_DRIVE = "google_drive"
|
||||
GMAIL = "gmail"
|
||||
DISCORD = "discord"
|
||||
S3_COMPATIBLE = "s3_compatible"
|
||||
|
||||
|
||||
class FileOrigin(str, Enum):
|
||||
@ -178,6 +182,21 @@ GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
|
||||
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
if ignored_tag
|
||||
]
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
|
||||
os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
|
||||
)
|
||||
JIRA_SYNC_TIME_BUFFER_SECONDS = int(
|
||||
os.environ.get("JIRA_SYNC_TIME_BUFFER_SECONDS", ONE_MINUTE)
|
||||
)
|
||||
JIRA_TIMEZONE_OFFSET = float(
|
||||
os.environ.get("JIRA_TIMEZONE_OFFSET", get_current_tz_offset())
|
||||
)
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
|
||||
@ -1788,6 +1788,7 @@ class ConfluenceConnector(
|
||||
cql_url = self.confluence_client.build_cql_url(
|
||||
page_query, expand=",".join(_PAGE_EXPANSION_FIELDS)
|
||||
)
|
||||
logging.info(f"[Confluence Connector] Building CQL URL {cql_url}")
|
||||
return update_param_in_path(cql_url, "limit", str(limit))
|
||||
|
||||
@override
|
||||
|
||||
@ -3,15 +3,9 @@ import os
|
||||
import threading
|
||||
from typing import Any, Callable
|
||||
|
||||
import requests
|
||||
|
||||
from common.data_source.config import DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_SCOPES
|
||||
|
||||
GOOGLE_DEVICE_CODE_URL = "https://oauth2.googleapis.com/device/code"
|
||||
GOOGLE_DEVICE_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
DEFAULT_DEVICE_INTERVAL = 5
|
||||
|
||||
|
||||
def _get_requested_scopes(source: DocumentSource) -> list[str]:
|
||||
"""Return the scopes to request, honoring an optional override env var."""
|
||||
@ -55,62 +49,6 @@ def _run_with_timeout(func: Callable[[], Any], timeout_secs: int, timeout_messag
|
||||
return result.get("value")
|
||||
|
||||
|
||||
def _extract_client_info(credentials: dict[str, Any]) -> tuple[str, str | None]:
|
||||
if "client_id" in credentials:
|
||||
return credentials["client_id"], credentials.get("client_secret")
|
||||
for key in ("installed", "web"):
|
||||
if key in credentials and isinstance(credentials[key], dict):
|
||||
nested = credentials[key]
|
||||
if "client_id" not in nested:
|
||||
break
|
||||
return nested["client_id"], nested.get("client_secret")
|
||||
raise ValueError("Provided Google OAuth credentials are missing client_id.")
|
||||
|
||||
|
||||
def start_device_authorization_flow(
|
||||
credentials: dict[str, Any],
|
||||
source: DocumentSource,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
client_id, client_secret = _extract_client_info(credentials)
|
||||
data = {
|
||||
"client_id": client_id,
|
||||
"scope": " ".join(_get_requested_scopes(source)),
|
||||
}
|
||||
if client_secret:
|
||||
data["client_secret"] = client_secret
|
||||
resp = requests.post(GOOGLE_DEVICE_CODE_URL, data=data, timeout=15)
|
||||
resp.raise_for_status()
|
||||
payload = resp.json()
|
||||
state = {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
"device_code": payload.get("device_code"),
|
||||
"interval": payload.get("interval", DEFAULT_DEVICE_INTERVAL),
|
||||
}
|
||||
response_data = {
|
||||
"user_code": payload.get("user_code"),
|
||||
"verification_url": payload.get("verification_url") or payload.get("verification_uri"),
|
||||
"verification_url_complete": payload.get("verification_url_complete")
|
||||
or payload.get("verification_uri_complete"),
|
||||
"expires_in": payload.get("expires_in"),
|
||||
"interval": state["interval"],
|
||||
}
|
||||
return state, response_data
|
||||
|
||||
|
||||
def poll_device_authorization_flow(state: dict[str, Any]) -> dict[str, Any]:
|
||||
data = {
|
||||
"client_id": state["client_id"],
|
||||
"device_code": state["device_code"],
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}
|
||||
if state.get("client_secret"):
|
||||
data["client_secret"] = state["client_secret"]
|
||||
resp = requests.post(GOOGLE_DEVICE_TOKEN_URL, data=data, timeout=20)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource) -> dict[str, Any]:
|
||||
"""Launch the standard Google OAuth local-server flow to mint user tokens."""
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
@ -125,10 +63,7 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
|
||||
preferred_port = os.environ.get("GOOGLE_OAUTH_LOCAL_SERVER_PORT")
|
||||
port = int(preferred_port) if preferred_port else 0
|
||||
timeout_secs = _get_oauth_timeout_secs()
|
||||
timeout_message = (
|
||||
f"Google OAuth verification timed out after {timeout_secs} seconds. "
|
||||
"Close any pending consent windows and rerun the connector configuration to try again."
|
||||
)
|
||||
timeout_message = f"Google OAuth verification timed out after {timeout_secs} seconds. Close any pending consent windows and rerun the connector configuration to try again."
|
||||
|
||||
print("Launching Google OAuth flow. A browser window should open shortly.")
|
||||
print("If it does not, copy the URL shown in the console into your browser manually.")
|
||||
@ -153,11 +88,8 @@ def _run_local_server_flow(client_config: dict[str, Any], source: DocumentSource
|
||||
instructions = [
|
||||
"Google rejected one or more of the requested OAuth scopes.",
|
||||
"Fix options:",
|
||||
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes "
|
||||
" (Drive metadata + Admin Directory read scopes), then re-run the flow.",
|
||||
" 1. In Google Cloud Console, open APIs & Services > OAuth consent screen and add the missing scopes (Drive metadata + Admin Directory read scopes), then re-run the flow.",
|
||||
" 2. Set GOOGLE_OAUTH_SCOPE_OVERRIDE to a comma-separated list of scopes you are allowed to request.",
|
||||
" 3. For quick local testing only, export OAUTHLIB_RELAX_TOKEN_SCOPE=1 to accept the reduced scopes "
|
||||
" (be aware the connector may lose functionality).",
|
||||
]
|
||||
raise RuntimeError("\n".join(instructions)) from warning
|
||||
raise
|
||||
@ -184,8 +116,6 @@ def ensure_oauth_token_dict(credentials: dict[str, Any], source: DocumentSource)
|
||||
client_config = {"web": credentials["web"]}
|
||||
|
||||
if client_config is None:
|
||||
raise ValueError(
|
||||
"Provided Google OAuth credentials are missing both tokens and a client configuration."
|
||||
)
|
||||
raise ValueError("Provided Google OAuth credentials are missing both tokens and a client configuration.")
|
||||
|
||||
return _run_local_server_flow(client_config, source)
|
||||
|
||||
@ -69,7 +69,7 @@ class SlimConnectorWithPermSync(ABC):
|
||||
|
||||
|
||||
class CheckpointedConnectorWithPermSync(ABC):
|
||||
"""Checkpointed connector interface (with permission sync)"""
|
||||
"""Checkpoint connector interface (with permission sync)"""
|
||||
|
||||
@abstractmethod
|
||||
def load_from_checkpoint(
|
||||
@ -143,7 +143,7 @@ class CredentialsProviderInterface(abc.ABC, Generic[T]):
|
||||
|
||||
@abc.abstractmethod
|
||||
def is_dynamic(self) -> bool:
|
||||
"""If dynamic, the credentials may change during usage ... maening the client
|
||||
"""If dynamic, the credentials may change during usage ... meaning the client
|
||||
needs to use the locking features of the credentials provider to operate
|
||||
correctly.
|
||||
|
||||
|
||||
common/data_source/jira/__init__.py (new file, 0 lines)

common/data_source/jira/connector.py (new file, 973 lines)
@ -0,0 +1,973 @@
|
||||
"""Checkpointed Jira connector that emits markdown blobs for each issue."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from jira import JIRA
|
||||
from jira.resources import Issue
|
||||
from pydantic import Field
|
||||
|
||||
from common.data_source.config import (
|
||||
INDEX_BATCH_SIZE,
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
JIRA_CONNECTOR_MAX_TICKET_SIZE,
|
||||
JIRA_TIMEZONE_OFFSET,
|
||||
ONE_HOUR,
|
||||
DocumentSource,
|
||||
)
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorMissingCredentialError,
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError,
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
CheckpointOutputWrapper,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync,
|
||||
)
|
||||
from common.data_source.jira.utils import (
|
||||
JIRA_CLOUD_API_VERSION,
|
||||
JIRA_SERVER_API_VERSION,
|
||||
build_issue_url,
|
||||
extract_body_text,
|
||||
extract_named_value,
|
||||
extract_user,
|
||||
format_attachments,
|
||||
format_comments,
|
||||
parse_jira_datetime,
|
||||
should_skip_issue,
|
||||
)
|
||||
from common.data_source.models import (
|
||||
ConnectorCheckpoint,
|
||||
ConnectorFailure,
|
||||
Document,
|
||||
DocumentFailure,
|
||||
SlimDocument,
|
||||
)
|
||||
from common.data_source.utils import is_atlassian_cloud_url, is_atlassian_date_error, scoped_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_FIELDS = "summary,description,updated,created,status,priority,assignee,reporter,labels,issuetype,project,comment,attachment"
|
||||
_SLIM_FIELDS = "key,project"
|
||||
_MAX_RESULTS_FETCH_IDS = 5000
|
||||
_JIRA_SLIM_PAGE_SIZE = 500
|
||||
_JIRA_FULL_PAGE_SIZE = 50
|
||||
_DEFAULT_ATTACHMENT_SIZE_LIMIT = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
|
||||
class JiraCheckpoint(ConnectorCheckpoint):
|
||||
"""Checkpoint that tracks which slice of the current JQL result set was emitted."""
|
||||
|
||||
start_at: int = 0
|
||||
cursor: str | None = None
|
||||
ids_done: bool = False
|
||||
all_issue_ids: list[list[str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
_TZ_OFFSET_PATTERN = re.compile(r"([+-])(\d{2})(:?)(\d{2})$")
|
||||
|
||||
|
||||
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||
"""Retrieve Jira issues and emit them as markdown documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
project_key: str | None = None,
|
||||
jql_query: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
include_comments: bool = True,
|
||||
include_attachments: bool = False,
|
||||
labels_to_skip: Sequence[str] | None = None,
|
||||
comment_email_blacklist: Sequence[str] | None = None,
|
||||
scoped_token: bool = False,
|
||||
attachment_size_limit: int | None = None,
|
||||
timezone_offset: float | None = None,
|
||||
) -> None:
|
||||
if not jira_base_url:
|
||||
raise ConnectorValidationError("Jira base URL must be provided.")
|
||||
|
||||
self.jira_base_url = jira_base_url.rstrip("/")
|
||||
self.project_key = project_key
|
||||
self.jql_query = jql_query
|
||||
self.batch_size = batch_size
|
||||
self.include_comments = include_comments
|
||||
self.include_attachments = include_attachments
|
||||
configured_labels = labels_to_skip or JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
self.labels_to_skip = {label.lower() for label in configured_labels}
|
||||
self.comment_email_blacklist = {email.lower() for email in comment_email_blacklist or []}
|
||||
self.scoped_token = scoped_token
|
||||
self.jira_client: JIRA | None = None
|
||||
|
||||
self.max_ticket_size = JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
self.attachment_size_limit = attachment_size_limit if attachment_size_limit and attachment_size_limit > 0 else _DEFAULT_ATTACHMENT_SIZE_LIMIT
|
||||
self._fields_param = _DEFAULT_FIELDS
|
||||
self._slim_fields = _SLIM_FIELDS
|
||||
|
||||
tz_offset_value = float(timezone_offset) if timezone_offset is not None else float(JIRA_TIMEZONE_OFFSET)
|
||||
self.timezone_offset = tz_offset_value
|
||||
self.timezone = timezone(offset=timedelta(hours=tz_offset_value))
|
||||
self._timezone_overridden = timezone_offset is not None
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Connector lifecycle helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Instantiate the Jira client using either an API token or username/password."""
|
||||
jira_url_for_client = self.jira_base_url
|
||||
if self.scoped_token:
|
||||
if is_atlassian_cloud_url(self.jira_base_url):
|
||||
try:
|
||||
jira_url_for_client = scoped_url(self.jira_base_url, "jira")
|
||||
except ValueError as exc:
|
||||
raise ConnectorValidationError(str(exc)) from exc
|
||||
else:
|
||||
logger.warning(f"[Jira] Scoped token requested but Jira base URL {self.jira_base_url} does not appear to be an Atlassian Cloud domain; scoped token ignored.")
|
||||
|
||||
user_email = credentials.get("jira_user_email") or credentials.get("username")
|
||||
api_token = credentials.get("jira_api_token") or credentials.get("token") or credentials.get("api_token")
|
||||
password = credentials.get("jira_password") or credentials.get("password")
|
||||
rest_api_version = credentials.get("rest_api_version")
|
||||
|
||||
if not rest_api_version:
|
||||
rest_api_version = JIRA_CLOUD_API_VERSION if api_token else JIRA_SERVER_API_VERSION
|
||||
options: dict[str, Any] = {"rest_api_version": rest_api_version}
|
||||
|
||||
try:
|
||||
if user_email and api_token:
|
||||
self.jira_client = JIRA(
|
||||
server=jira_url_for_client,
|
||||
basic_auth=(user_email, api_token),
|
||||
options=options,
|
||||
)
|
||||
elif api_token:
|
||||
self.jira_client = JIRA(
|
||||
server=jira_url_for_client,
|
||||
token_auth=api_token,
|
||||
options=options,
|
||||
)
|
||||
elif user_email and password:
|
||||
self.jira_client = JIRA(
|
||||
server=jira_url_for_client,
|
||||
basic_auth=(user_email, password),
|
||||
options=options,
|
||||
)
|
||||
else:
|
||||
raise ConnectorMissingCredentialError("Jira credentials must include either an API token or username/password.")
|
||||
except Exception as exc: # pragma: no cover - jira lib raises many types
|
||||
raise ConnectorMissingCredentialError(f"Jira: {exc}") from exc
|
||||
self._sync_timezone_from_server()
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate connectivity by fetching basic Jira info."""
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
try:
|
||||
if self.jql_query:
|
||||
dummy_checkpoint = self.build_dummy_checkpoint()
|
||||
checkpoint_callback = self._make_checkpoint_callback(dummy_checkpoint)
|
||||
iterator = self._perform_jql_search(
|
||||
jql=self.jql_query,
|
||||
start=0,
|
||||
max_results=1,
|
||||
fields="key",
|
||||
all_issue_ids=dummy_checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
next_page_token=dummy_checkpoint.cursor,
|
||||
ids_done=dummy_checkpoint.ids_done,
|
||||
)
|
||||
next(iter(iterator), None)
|
||||
elif self.project_key:
|
||||
self.jira_client.project(self.project_key)
|
||||
else:
|
||||
self.jira_client.projects()
|
||||
except Exception as exc: # pragma: no cover - dependent on Jira responses
|
||||
self._handle_validation_error(exc)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Checkpointed connector implementation
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
|
||||
"""Load Jira issues, emitting a Document per issue."""
|
||||
try:
|
||||
return (yield from self._load_with_retry(start, end, checkpoint))
|
||||
except Exception as exc:
|
||||
logger.exception(f"[Jira] Jira query ultimately failed: {exc}")
|
||||
yield ConnectorFailure(
|
||||
failure_message=f"Failed to query Jira: {exc}",
|
||||
exception=exc,
|
||||
)
|
||||
return JiraCheckpoint(has_more=False, start_at=checkpoint.start_at)
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
|
||||
"""Permissions are not synced separately, so reuse the standard loader."""
|
||||
return (yield from self.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint))
|
||||
|
||||
def _load_with_retry(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: JiraCheckpoint,
|
||||
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
attempt_start = start
|
||||
retried_with_buffer = False
|
||||
attempt = 0
|
||||
|
||||
while True:
|
||||
attempt += 1
|
||||
jql = self._build_jql(attempt_start, end)
|
||||
logger.info(f"[Jira] Executing Jira JQL attempt {attempt} (start={attempt_start}, end={end}, buffered_retry={retried_with_buffer}): {jql}")
|
||||
try:
|
||||
return (yield from self._load_from_checkpoint_internal(jql, checkpoint, start_filter=start))
|
||||
except Exception as exc:
|
||||
if attempt_start is not None and not retried_with_buffer and is_atlassian_date_error(exc):
|
||||
attempt_start = attempt_start - ONE_HOUR
|
||||
retried_with_buffer = True
|
||||
logger.info(f"[Jira] Atlassian date error detected; retrying with start={attempt_start}.")
|
||||
continue
|
||||
raise
|
||||
|
||||
def _handle_validation_error(self, exc: Exception) -> None:
|
||||
status_code = getattr(exc, "status_code", None)
|
||||
if status_code == 401:
|
||||
raise InsufficientPermissionsError("Jira credential appears to be invalid or expired (HTTP 401).") from exc
|
||||
if status_code == 403:
|
||||
raise InsufficientPermissionsError("Jira token does not have permission to access the requested resources (HTTP 403).") from exc
|
||||
if status_code == 404:
|
||||
raise ConnectorValidationError("Jira resource not found (HTTP 404).") from exc
|
||||
if status_code == 429:
|
||||
raise ConnectorValidationError("Jira rate limit exceeded during validation (HTTP 429).") from exc
|
||||
|
||||
message = getattr(exc, "text", str(exc))
|
||||
if not message:
|
||||
raise UnexpectedValidationError("Unexpected Jira validation error.") from exc
|
||||
|
||||
raise ConnectorValidationError(f"Jira validation failed: {message}") from exc
|
||||
|
||||
def _load_from_checkpoint_internal(
|
||||
self,
|
||||
jql: str,
|
||||
checkpoint: JiraCheckpoint,
|
||||
start_filter: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Generator[Document | ConnectorFailure, None, JiraCheckpoint]:
|
||||
assert self.jira_client, "load_credentials must be called before loading issues."
|
||||
|
||||
page_size = self._full_page_size()
|
||||
new_checkpoint = copy.deepcopy(checkpoint)
|
||||
starting_offset = new_checkpoint.start_at or 0
|
||||
current_offset = starting_offset
|
||||
checkpoint_callback = self._make_checkpoint_callback(new_checkpoint)
|
||||
|
||||
issue_iter = self._perform_jql_search(
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=page_size,
|
||||
fields=self._fields_param,
|
||||
all_issue_ids=new_checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
next_page_token=new_checkpoint.cursor,
|
||||
ids_done=new_checkpoint.ids_done,
|
||||
)
|
||||
|
||||
start_cutoff = float(start_filter) if start_filter is not None else None
|
||||
|
||||
for issue in issue_iter:
|
||||
current_offset += 1
|
||||
issue_key = getattr(issue, "key", "unknown")
|
||||
if should_skip_issue(issue, self.labels_to_skip):
|
||||
continue
|
||||
|
||||
issue_updated = parse_jira_datetime(issue.raw.get("fields", {}).get("updated"))
|
||||
if start_cutoff is not None and issue_updated is not None and issue_updated.timestamp() <= start_cutoff:
|
||||
# Jira JQL only supports minute precision, so we discard already-processed
|
||||
# issues here based on the original second-level cutoff.
|
||||
continue
|
||||
|
||||
try:
|
||||
document = self._issue_to_document(issue)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.exception(f"[Jira] Failed to convert Jira issue {issue_key}: {exc}")
|
||||
yield ConnectorFailure(
|
||||
failure_message=f"Failed to convert Jira issue {issue_key}: {exc}",
|
||||
failed_document=DocumentFailure(
|
||||
document_id=issue_key,
|
||||
document_link=build_issue_url(self.jira_base_url, issue_key),
|
||||
),
|
||||
exception=exc,
|
||||
)
|
||||
continue
|
||||
|
||||
if document is not None:
|
||||
yield document
|
||||
if self.include_attachments:
|
||||
for attachment_document in self._attachment_documents(issue):
|
||||
if attachment_document is not None:
|
||||
yield attachment_document
|
||||
|
||||
self._update_checkpoint_for_next_run(
|
||||
checkpoint=new_checkpoint,
|
||||
current_offset=current_offset,
|
||||
starting_offset=starting_offset,
|
||||
page_size=page_size,
|
||||
)
|
||||
new_checkpoint.start_at = current_offset
|
||||
return new_checkpoint
|
||||
|
||||
def build_dummy_checkpoint(self) -> JiraCheckpoint:
|
||||
"""Create an empty checkpoint used to kick off ingestion."""
|
||||
return JiraCheckpoint(has_more=True, start_at=0)
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> JiraCheckpoint:
|
||||
"""Validate a serialized checkpoint."""
|
||||
return JiraCheckpoint.model_validate_json(checkpoint_json)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Slim connector implementation
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None, # noqa: ARG002 - maintained for interface compatibility
|
||||
) -> Generator[list[SlimDocument], None, None]:
|
||||
"""Return lightweight references to Jira issues (used for permission syncing)."""
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
start_ts = start if start is not None else 0
|
||||
end_ts = end if end is not None else datetime.now(timezone.utc).timestamp()
|
||||
jql = self._build_jql(start_ts, end_ts)
|
||||
|
||||
checkpoint = self.build_dummy_checkpoint()
|
||||
checkpoint_callback = self._make_checkpoint_callback(checkpoint)
|
||||
prev_offset = 0
|
||||
current_offset = 0
|
||||
slim_batch: list[SlimDocument] = []
|
||||
|
||||
while checkpoint.has_more:
|
||||
for issue in self._perform_jql_search(
|
||||
jql=jql,
|
||||
start=current_offset,
|
||||
max_results=_JIRA_SLIM_PAGE_SIZE,
|
||||
fields=self._slim_fields,
|
||||
all_issue_ids=checkpoint.all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
next_page_token=checkpoint.cursor,
|
||||
ids_done=checkpoint.ids_done,
|
||||
):
|
||||
current_offset += 1
|
||||
if should_skip_issue(issue, self.labels_to_skip):
|
||||
continue
|
||||
|
||||
doc_id = build_issue_url(self.jira_base_url, issue.key)
|
||||
slim_batch.append(SlimDocument(id=doc_id))
|
||||
|
||||
if len(slim_batch) >= _JIRA_SLIM_PAGE_SIZE:
|
||||
yield slim_batch
|
||||
slim_batch = []
|
||||
|
||||
self._update_checkpoint_for_next_run(
|
||||
checkpoint=checkpoint,
|
||||
current_offset=current_offset,
|
||||
starting_offset=prev_offset,
|
||||
page_size=_JIRA_SLIM_PAGE_SIZE,
|
||||
)
|
||||
prev_offset = current_offset
|
||||
|
||||
if slim_batch:
|
||||
yield slim_batch
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _build_jql(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> str:
|
||||
clauses: list[str] = []
|
||||
if self.jql_query:
|
||||
clauses.append(f"({self.jql_query})")
|
||||
elif self.project_key:
|
||||
clauses.append(f'project = "{self.project_key}"')
|
||||
else:
|
||||
raise ConnectorValidationError("Either project_key or jql_query must be provided for Jira connector.")
|
||||
|
||||
if self.labels_to_skip:
|
||||
labels = ", ".join(f'"{label}"' for label in self.labels_to_skip)
|
||||
clauses.append(f"labels NOT IN ({labels})")
|
||||
|
||||
if start is not None:
|
||||
clauses.append(f'updated >= "{self._format_jql_time(start)}"')
|
||||
if end is not None:
|
||||
clauses.append(f'updated <= "{self._format_jql_time(end)}"')
|
||||
|
||||
if not clauses:
|
||||
raise ConnectorValidationError("Unable to build Jira JQL query.")
|
||||
|
||||
jql = " AND ".join(clauses)
|
||||
if "order by" not in jql.lower():
|
||||
jql = f"{jql} ORDER BY updated ASC"
|
||||
return jql
|
||||
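# Illustrative output (assumed inputs, not from the original source): with
# project_key="ENG", labels_to_skip={"no-index"} and a one-day window, this
# method returns a single JQL string along the lines of:
#   project = "ENG" AND labels NOT IN ("no-index") AND
#   updated >= "2024-01-01 00:00" AND updated <= "2024-01-02 00:00" ORDER BY updated ASC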
|
||||
def _format_jql_time(self, timestamp: SecondsSinceUnixEpoch) -> str:
|
||||
dt_utc = datetime.fromtimestamp(float(timestamp), tz=timezone.utc)
|
||||
dt_local = dt_utc.astimezone(self.timezone)
|
||||
# Jira only accepts minute-precision timestamps in JQL, so we format accordingly
|
||||
# and rely on a post-query second-level filter to avoid duplicates.
|
||||
return dt_local.strftime("%Y-%m-%d %H:%M")
|
||||
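# Illustrative example (assumed offset of 0 hours): timestamp 1700000045
# (2023-11-14 22:14:05 UTC) is rendered as "2023-11-14 22:14"; the dropped
# seconds are compensated for by the second-level cutoff applied after the query.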
|
||||
def _issue_to_document(self, issue: Issue) -> Document | None:
|
||||
fields = issue.raw.get("fields", {})
|
||||
summary = fields.get("summary") or ""
|
||||
description_text = extract_body_text(fields.get("description"))
|
||||
comments_text = (
|
||||
format_comments(
|
||||
fields.get("comment"),
|
||||
blacklist=self.comment_email_blacklist,
|
||||
)
|
||||
if self.include_comments
|
||||
else ""
|
||||
)
|
||||
attachments_text = format_attachments(fields.get("attachment"))
|
||||
|
||||
reporter_name, reporter_email = extract_user(fields.get("reporter"))
|
||||
assignee_name, assignee_email = extract_user(fields.get("assignee"))
|
||||
status = extract_named_value(fields.get("status"))
|
||||
priority = extract_named_value(fields.get("priority"))
|
||||
issue_type = extract_named_value(fields.get("issuetype"))
|
||||
project = fields.get("project") or {}
|
||||
|
||||
issue_url = build_issue_url(self.jira_base_url, issue.key)
|
||||
|
||||
metadata_lines = [
|
||||
f"key: {issue.key}",
|
||||
f"url: {issue_url}",
|
||||
f"summary: {summary}",
|
||||
f"status: {status or 'Unknown'}",
|
||||
f"priority: {priority or 'Unspecified'}",
|
||||
f"issue_type: {issue_type or 'Unknown'}",
|
||||
f"project: {project.get('name') or ''}",
|
||||
f"project_key: {project.get('key') or self.project_key or ''}",
|
||||
]
|
||||
|
||||
if reporter_name:
|
||||
metadata_lines.append(f"reporter: {reporter_name}")
|
||||
if reporter_email:
|
||||
metadata_lines.append(f"reporter_email: {reporter_email}")
|
||||
if assignee_name:
|
||||
metadata_lines.append(f"assignee: {assignee_name}")
|
||||
if assignee_email:
|
||||
metadata_lines.append(f"assignee_email: {assignee_email}")
|
||||
if fields.get("labels"):
|
||||
metadata_lines.append(f"labels: {', '.join(fields.get('labels'))}")
|
||||
|
||||
created_dt = parse_jira_datetime(fields.get("created"))
|
||||
updated_dt = parse_jira_datetime(fields.get("updated")) or created_dt or datetime.now(timezone.utc)
|
||||
metadata_lines.append(f"created: {created_dt.isoformat() if created_dt else ''}")
|
||||
metadata_lines.append(f"updated: {updated_dt.isoformat() if updated_dt else ''}")
|
||||
|
||||
sections: list[str] = [
|
||||
"---",
|
||||
"\n".join(filter(None, metadata_lines)),
|
||||
"---",
|
||||
"",
|
||||
"## Description",
|
||||
description_text or "No description provided.",
|
||||
]
|
||||
|
||||
if comments_text:
|
||||
sections.extend(["", "## Comments", comments_text])
|
||||
if attachments_text:
|
||||
sections.extend(["", "## Attachments", attachments_text])
|
||||
|
||||
blob_text = "\n".join(sections).strip() + "\n"
|
||||
blob = blob_text.encode("utf-8")
|
||||
|
||||
if len(blob) > self.max_ticket_size:
|
||||
logger.info(f"[Jira] Skipping {issue.key} because it exceeds the maximum size of {self.max_ticket_size} bytes.")
|
||||
return None
|
||||
|
||||
semantic_identifier = f"{issue.key}: {summary}" if summary else issue.key
|
||||
|
||||
return Document(
|
||||
id=issue_url,
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=semantic_identifier,
|
||||
extension=".md",
|
||||
blob=blob,
|
||||
doc_updated_at=updated_dt,
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
|
||||
def _attachment_documents(self, issue: Issue) -> Iterable[Document]:
|
||||
attachments = issue.raw.get("fields", {}).get("attachment") or []
|
||||
for attachment in attachments:
|
||||
try:
|
||||
document = self._attachment_to_document(issue, attachment)
|
||||
if document is not None:
|
||||
yield document
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
failed_id = attachment.get("id") or attachment.get("filename")
|
||||
issue_key = getattr(issue, "key", "unknown")
|
||||
logger.warning(f"[Jira] Failed to process attachment {failed_id} for issue {issue_key}: {exc}")
|
||||
|
||||
def _attachment_to_document(self, issue: Issue, attachment: dict[str, Any]) -> Document | None:
|
||||
if not self.include_attachments:
|
||||
return None
|
||||
|
||||
filename = attachment.get("filename")
|
||||
content_url = attachment.get("content")
|
||||
if not filename or not content_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
attachment_size = int(attachment.get("size", 0))
|
||||
except (TypeError, ValueError):
|
||||
attachment_size = 0
|
||||
if attachment_size and attachment_size > self.attachment_size_limit:
|
||||
logger.info(f"[Jira] Skipping attachment {filename} on {issue.key} because reported size exceeds limit ({self.attachment_size_limit} bytes).")
|
||||
return None
|
||||
|
||||
blob = self._download_attachment(content_url)
|
||||
if blob is None:
|
||||
return None
|
||||
|
||||
if len(blob) > self.attachment_size_limit:
|
||||
logger.info(f"[Jira] Skipping attachment {filename} on {issue.key} because it exceeds the size limit ({self.attachment_size_limit} bytes).")
|
||||
return None
|
||||
|
||||
attachment_time = parse_jira_datetime(attachment.get("created")) or parse_jira_datetime(attachment.get("updated"))
|
||||
updated_dt = attachment_time or parse_jira_datetime(issue.raw.get("fields", {}).get("updated")) or datetime.now(timezone.utc)
|
||||
|
||||
extension = os.path.splitext(filename)[1] or ""
|
||||
document_id = f"{issue.key}::attachment::{attachment.get('id') or filename}"
|
||||
semantic_identifier = f"{issue.key} attachment: {filename}"
|
||||
|
||||
return Document(
|
||||
id=document_id,
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=semantic_identifier,
|
||||
extension=extension,
|
||||
blob=blob,
|
||||
doc_updated_at=updated_dt,
|
||||
size_bytes=len(blob),
|
||||
)
|
||||
|
||||
def _download_attachment(self, url: str) -> bytes | None:
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
response = self.jira_client._session.get(url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
def _sync_timezone_from_server(self) -> None:
|
||||
if self._timezone_overridden or not self.jira_client:
|
||||
return
|
||||
try:
|
||||
server_info = self.jira_client.server_info()
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.info(f"[Jira] Unable to determine timezone from server info; continuing with offset {self.timezone_offset}. Error: {exc}")
|
||||
return
|
||||
|
||||
detected_offset = self._extract_timezone_offset(server_info)
|
||||
if detected_offset is None or detected_offset == self.timezone_offset:
|
||||
return
|
||||
|
||||
self.timezone_offset = detected_offset
|
||||
self.timezone = timezone(offset=timedelta(hours=detected_offset))
|
||||
logger.info(f"[Jira] Timezone offset adjusted to {detected_offset} hours using Jira server info.")
|
||||
|
||||
def _extract_timezone_offset(self, server_info: dict[str, Any]) -> float | None:
|
||||
server_time_raw = server_info.get("serverTime")
|
||||
if isinstance(server_time_raw, str):
|
||||
offset = self._parse_offset_from_datetime_string(server_time_raw)
|
||||
if offset is not None:
|
||||
return offset
|
||||
|
||||
tz_name = server_info.get("timeZone")
|
||||
if isinstance(tz_name, str):
|
||||
offset = self._offset_from_zone_name(tz_name)
|
||||
if offset is not None:
|
||||
return offset
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_offset_from_datetime_string(value: str) -> float | None:
|
||||
normalized = JiraConnector._normalize_datetime_string(value)
|
||||
try:
|
||||
dt = datetime.fromisoformat(normalized)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
if dt.tzinfo is None:
|
||||
return 0.0
|
||||
|
||||
offset = dt.tzinfo.utcoffset(dt)
|
||||
if offset is None:
|
||||
return None
|
||||
return offset.total_seconds() / 3600.0
|
||||
|
||||
@staticmethod
|
||||
def _normalize_datetime_string(value: str) -> str:
|
||||
trimmed = (value or "").strip()
|
||||
if trimmed.endswith("Z"):
|
||||
return f"{trimmed[:-1]}+00:00"
|
||||
|
||||
match = _TZ_OFFSET_PATTERN.search(trimmed)
|
||||
if match and match.group(3) != ":":
|
||||
sign, hours, _, minutes = match.groups()
|
||||
trimmed = f"{trimmed[: match.start()]}{sign}{hours}:{minutes}"
|
||||
return trimmed
|
||||
|
||||
@staticmethod
|
||||
def _offset_from_zone_name(name: str) -> float | None:
|
||||
try:
|
||||
tz = ZoneInfo(name)
|
||||
except (ZoneInfoNotFoundError, ValueError):
|
||||
return None
|
||||
reference = datetime.now(tz)
|
||||
offset = reference.utcoffset()
|
||||
if offset is None:
|
||||
return None
|
||||
return offset.total_seconds() / 3600.0
|
||||
|
||||
def _is_cloud_client(self) -> bool:
|
||||
if not self.jira_client:
|
||||
return False
|
||||
rest_version = str(self.jira_client._options.get("rest_api_version", "")).strip()
|
||||
return rest_version == str(JIRA_CLOUD_API_VERSION)
|
||||
|
||||
def _full_page_size(self) -> int:
|
||||
return max(1, min(self.batch_size, _JIRA_FULL_PAGE_SIZE))
|
||||
|
||||
def _perform_jql_search(
|
||||
self,
|
||||
*,
|
||||
jql: str,
|
||||
start: int,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
all_issue_ids: list[list[str]] | None = None,
|
||||
checkpoint_callback: Callable[[Iterable[list[str]], str | None], None] | None = None,
|
||||
next_page_token: str | None = None,
|
||||
ids_done: bool = False,
|
||||
) -> Iterable[Issue]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
is_cloud = self._is_cloud_client()
|
||||
if is_cloud:
|
||||
if all_issue_ids is None:
|
||||
raise ValueError("all_issue_ids is required for Jira Cloud searches.")
|
||||
yield from self._perform_jql_search_v3(
|
||||
jql=jql,
|
||||
max_results=max_results,
|
||||
fields=fields,
|
||||
all_issue_ids=all_issue_ids,
|
||||
checkpoint_callback=checkpoint_callback,
|
||||
next_page_token=next_page_token,
|
||||
ids_done=ids_done,
|
||||
)
|
||||
else:
|
||||
yield from self._perform_jql_search_v2(
|
||||
jql=jql,
|
||||
start=start,
|
||||
max_results=max_results,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
def _perform_jql_search_v3(
|
||||
self,
|
||||
*,
|
||||
jql: str,
|
||||
max_results: int,
|
||||
all_issue_ids: list[list[str]],
|
||||
fields: str | None = None,
|
||||
checkpoint_callback: Callable[[Iterable[list[str]], str | None], None] | None = None,
|
||||
next_page_token: str | None = None,
|
||||
ids_done: bool = False,
|
||||
) -> Iterable[Issue]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
|
||||
if not ids_done:
|
||||
new_ids, page_token = self._enhanced_search_ids(jql, next_page_token)
|
||||
if checkpoint_callback is not None and new_ids:
|
||||
checkpoint_callback(
|
||||
self._chunk_issue_ids(new_ids, max_results),
|
||||
page_token,
|
||||
)
|
||||
elif checkpoint_callback is not None:
|
||||
checkpoint_callback([], page_token)
|
||||
|
||||
if all_issue_ids:
|
||||
issue_ids = all_issue_ids.pop()
|
||||
if issue_ids:
|
||||
yield from self._bulk_fetch_issues(issue_ids, fields)
|
||||
|
||||
def _perform_jql_search_v2(
|
||||
self,
|
||||
*,
|
||||
jql: str,
|
||||
start: int,
|
||||
max_results: int,
|
||||
fields: str | None = None,
|
||||
) -> Iterable[Issue]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
|
||||
issues = self.jira_client.search_issues(
|
||||
jql_str=jql,
|
||||
startAt=start,
|
||||
maxResults=max_results,
|
||||
fields=fields or self._fields_param,
|
||||
expand="renderedFields",
|
||||
)
|
||||
for issue in issues:
|
||||
yield issue
|
||||
|
||||
def _enhanced_search_ids(
|
||||
self,
|
||||
jql: str,
|
||||
next_page_token: str | None,
|
||||
) -> tuple[list[str], str | None]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
enhanced_search_path = self.jira_client._get_url("search/jql")
|
||||
params: dict[str, str | int | None] = {
|
||||
"jql": jql,
|
||||
"maxResults": _MAX_RESULTS_FETCH_IDS,
|
||||
"nextPageToken": next_page_token,
|
||||
"fields": "id",
|
||||
}
|
||||
response = self.jira_client._session.get(enhanced_search_path, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return [str(issue["id"]) for issue in data.get("issues", [])], data.get("nextPageToken")
|
||||
|
||||
def _bulk_fetch_issues(
|
||||
self,
|
||||
issue_ids: list[str],
|
||||
fields: str | None,
|
||||
) -> Iterable[Issue]:
|
||||
assert self.jira_client, "Jira client not initialized."
|
||||
if not issue_ids:
|
||||
return []
|
||||
|
||||
bulk_fetch_path = self.jira_client._get_url("issue/bulkfetch")
|
||||
payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids}
|
||||
payload["fields"] = fields.split(",") if fields else ["*all"]
|
||||
|
||||
response = self.jira_client._session.post(bulk_fetch_path, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return [Issue(self.jira_client._options, self.jira_client._session, raw=issue) for issue in data.get("issues", [])]
|
||||
|
||||
@staticmethod
|
||||
def _chunk_issue_ids(issue_ids: list[str], chunk_size: int) -> Iterable[list[str]]:
|
||||
if chunk_size <= 0:
|
||||
chunk_size = _JIRA_FULL_PAGE_SIZE
|
||||
|
||||
for idx in range(0, len(issue_ids), chunk_size):
|
||||
yield issue_ids[idx : idx + chunk_size]
|
||||
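# Illustrative example: _chunk_issue_ids(["1", "2", "3", "4", "5"], 2) yields
# ["1", "2"], ["3", "4"], ["5"]; a non-positive chunk_size falls back to
# _JIRA_FULL_PAGE_SIZE.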
|
||||
def _make_checkpoint_callback(self, checkpoint: JiraCheckpoint) -> Callable[[Iterable[list[str]], str | None], None]:
|
||||
def checkpoint_callback(
|
||||
issue_ids: Iterable[list[str]] | list[list[str]],
|
||||
page_token: str | None,
|
||||
) -> None:
|
||||
for id_batch in issue_ids:
|
||||
checkpoint.all_issue_ids.append(list(id_batch))
|
||||
checkpoint.cursor = page_token
|
||||
checkpoint.ids_done = page_token is None
|
||||
|
||||
return checkpoint_callback
|
||||
|
||||
def _update_checkpoint_for_next_run(
|
||||
self,
|
||||
*,
|
||||
checkpoint: JiraCheckpoint,
|
||||
current_offset: int,
|
||||
starting_offset: int,
|
||||
page_size: int,
|
||||
) -> None:
|
||||
if self._is_cloud_client():
|
||||
checkpoint.has_more = bool(checkpoint.all_issue_ids) or not checkpoint.ids_done
|
||||
else:
|
||||
checkpoint.has_more = current_offset - starting_offset == page_size
|
||||
checkpoint.cursor = None
|
||||
checkpoint.ids_done = True
|
||||
checkpoint.all_issue_ids = []
|
||||
|
||||
|
||||
def iterate_jira_documents(
|
||||
connector: "JiraConnector",
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
iteration_limit: int = 100_000,
|
||||
) -> Iterator[Document]:
|
||||
"""Yield documents without materializing the entire result set."""
|
||||
|
||||
checkpoint = connector.build_dummy_checkpoint()
|
||||
iterations = 0
|
||||
|
||||
while checkpoint.has_more:
|
||||
wrapper = CheckpointOutputWrapper[JiraCheckpoint]()
|
||||
generator = wrapper(connector.load_from_checkpoint(start=start, end=end, checkpoint=checkpoint))
|
||||
|
||||
for document, failure, next_checkpoint in generator:
|
||||
if failure is not None:
|
||||
failure_message = getattr(failure, "failure_message", str(failure))
|
||||
raise RuntimeError(f"Failed to load Jira documents: {failure_message}")
|
||||
if document is not None:
|
||||
yield document
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
iterations += 1
|
||||
if iterations > iteration_limit:
|
||||
raise RuntimeError("Too many iterations while loading Jira documents.")
|
||||
|
||||
|
||||
def test_jira(
|
||||
*,
|
||||
base_url: str,
|
||||
project_key: str | None = None,
|
||||
jql_query: str | None = None,
|
||||
credentials: dict[str, Any],
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
start_ts: float | None = None,
|
||||
end_ts: float | None = None,
|
||||
connector_options: dict[str, Any] | None = None,
|
||||
) -> list[Document]:
|
||||
"""Programmatic entry point that mirrors the CLI workflow."""
|
||||
|
||||
connector_kwargs = connector_options.copy() if connector_options else {}
|
||||
connector = JiraConnector(
|
||||
jira_base_url=base_url,
|
||||
project_key=project_key,
|
||||
jql_query=jql_query,
|
||||
batch_size=batch_size,
|
||||
**connector_kwargs,
|
||||
)
|
||||
connector.load_credentials(credentials)
|
||||
connector.validate_connector_settings()
|
||||
|
||||
now_ts = datetime.now(timezone.utc).timestamp()
|
||||
start = start_ts if start_ts is not None else 0.0
|
||||
end = end_ts if end_ts is not None else now_ts
|
||||
|
||||
documents = list(iterate_jira_documents(connector, start=start, end=end))
|
||||
logger.info(f"[Jira] Fetched {len(documents)} Jira documents.")
|
||||
for doc in documents[:5]:
|
||||
logger.info(f"[Jira] Document {doc.semantic_identifier} ({doc.id}) size={doc.size_bytes} bytes")
|
||||
return documents
|
||||
|
||||
|
||||
def _build_arg_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="Fetch Jira issues and print summary statistics.")
|
||||
parser.add_argument("--base-url", dest="base_url", default=os.environ.get("JIRA_BASE_URL"))
|
||||
parser.add_argument("--project", dest="project_key", default=os.environ.get("JIRA_PROJECT_KEY"))
|
||||
parser.add_argument("--jql", dest="jql_query", default=os.environ.get("JIRA_JQL"))
|
||||
parser.add_argument("--email", dest="user_email", default=os.environ.get("JIRA_USER_EMAIL"))
|
||||
parser.add_argument("--token", dest="api_token", default=os.environ.get("JIRA_API_TOKEN"))
|
||||
parser.add_argument("--password", dest="password", default=os.environ.get("JIRA_PASSWORD"))
|
||||
parser.add_argument("--batch-size", dest="batch_size", type=int, default=int(os.environ.get("JIRA_BATCH_SIZE", INDEX_BATCH_SIZE)))
|
||||
parser.add_argument("--include_comments", dest="include_comments", type=bool, default=True)
|
||||
parser.add_argument("--include_attachments", dest="include_attachments", type=bool, default=True)
|
||||
parser.add_argument("--attachment_size_limit", dest="attachment_size_limit", type=float, default=_DEFAULT_ATTACHMENT_SIZE_LIMIT)
|
||||
parser.add_argument("--start-ts", dest="start_ts", type=float, default=None, help="Epoch seconds inclusive lower bound for updated issues.")
|
||||
parser.add_argument("--end-ts", dest="end_ts", type=float, default=9999999999, help="Epoch seconds inclusive upper bound for updated issues.")
|
||||
return parser
|
||||
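# Illustrative invocation (placeholder values, not from the original source):
#   python this_connector_module.py --base-url https://example.atlassian.net \
#       --project ENG --email dev@example.com --token <api-token> --batch-size 50
# Flags fall back to the corresponding JIRA_* environment variables
# (JIRA_BASE_URL, JIRA_API_TOKEN, ...) when omitted.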
|
||||
|
||||
def main(config: dict[str, Any] | None = None) -> None:
|
||||
if config is None:
|
||||
args = _build_arg_parser().parse_args()
|
||||
config = {
|
||||
"base_url": args.base_url,
|
||||
"project_key": args.project_key,
|
||||
"jql_query": args.jql_query,
|
||||
"batch_size": args.batch_size,
|
||||
"start_ts": args.start_ts,
|
||||
"end_ts": args.end_ts,
|
||||
"include_comments": args.include_comments,
|
||||
"include_attachments": args.include_attachments,
|
||||
"attachment_size_limit": args.attachment_size_limit,
|
||||
"credentials": {
|
||||
"jira_user_email": args.user_email,
|
||||
"jira_api_token": args.api_token,
|
||||
"jira_password": args.password,
|
||||
},
|
||||
}
|
||||
|
||||
base_url = config.get("base_url")
|
||||
credentials = config.get("credentials", {})
|
||||
|
||||
print(f"[Jira] {config=}", flush=True)
|
||||
print(f"[Jira] {credentials=}", flush=True)
|
||||
|
||||
if not base_url:
|
||||
raise RuntimeError("Jira base URL must be provided via config or CLI arguments.")
|
||||
if not (credentials.get("jira_api_token") or (credentials.get("jira_user_email") and credentials.get("jira_password"))):
|
||||
raise RuntimeError("Provide either an API token or both email/password for Jira authentication.")
|
||||
|
||||
connector_options = {
|
||||
key: value
|
||||
for key, value in (
|
||||
("include_comments", config.get("include_comments")),
|
||||
("include_attachments", config.get("include_attachments")),
|
||||
("attachment_size_limit", config.get("attachment_size_limit")),
|
||||
("labels_to_skip", config.get("labels_to_skip")),
|
||||
("comment_email_blacklist", config.get("comment_email_blacklist")),
|
||||
("scoped_token", config.get("scoped_token")),
|
||||
("timezone_offset", config.get("timezone_offset")),
|
||||
)
|
||||
if value is not None
|
||||
}
|
||||
|
||||
documents = test_jira(
|
||||
base_url=base_url,
|
||||
project_key=config.get("project_key"),
|
||||
jql_query=config.get("jql_query"),
|
||||
credentials=credentials,
|
||||
batch_size=config.get("batch_size", INDEX_BATCH_SIZE),
|
||||
start_ts=config.get("start_ts"),
|
||||
end_ts=config.get("end_ts"),
|
||||
connector_options=connector_options,
|
||||
)
|
||||
|
||||
preview_count = min(len(documents), 5)
|
||||
for idx in range(preview_count):
|
||||
doc = documents[idx]
|
||||
print(f"[Jira] [Sample {idx + 1}] {doc.semantic_identifier} | id={doc.id} | size={doc.size_bytes} bytes")
|
||||
|
||||
print(f"[Jira] Jira connector test completed. Documents fetched: {len(documents)}")
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover - manual execution path
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
main()
|
||||
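# Illustrative programmatic usage of this module (assumed values, not part of
# the original source):
#
#     docs = test_jira(
#         base_url="https://example.atlassian.net",
#         project_key="ENG",
#         credentials={"jira_user_email": "dev@example.com", "jira_api_token": "<token>"},
#     )
#     print(f"fetched {len(docs)} documents")
#
# This mirrors what main() does when driven from the CLI arguments above.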
149
common/data_source/jira/utils.py
Normal file
@ -0,0 +1,149 @@
|
||||
"""Helper utilities for the Jira connector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Collection
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Iterable
|
||||
|
||||
from jira.resources import Issue
|
||||
|
||||
from common.data_source.utils import datetime_from_string
|
||||
|
||||
JIRA_SERVER_API_VERSION = os.environ.get("JIRA_SERVER_API_VERSION", "2")
|
||||
JIRA_CLOUD_API_VERSION = os.environ.get("JIRA_CLOUD_API_VERSION", "3")
|
||||
|
||||
|
||||
def build_issue_url(base_url: str, issue_key: str) -> str:
|
||||
"""Return the canonical UI URL for a Jira issue."""
|
||||
return f"{base_url.rstrip('/')}/browse/{issue_key}"
|
||||
|
||||
|
||||
def parse_jira_datetime(value: Any) -> datetime | None:
|
||||
"""Best-effort parse of Jira datetime values to aware UTC datetimes."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, datetime):
|
||||
return value.astimezone(timezone.utc) if value.tzinfo else value.replace(tzinfo=timezone.utc)
|
||||
if isinstance(value, str):
|
||||
return datetime_from_string(value)
|
||||
return None
|
||||
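# Illustrative examples (assumed inputs): a Jira string such as
# "2024-01-15T10:30:00.000+0000" is delegated to datetime_from_string, a naive
# datetime is tagged as UTC, and unsupported types (ints, lists, ...) return None.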
|
||||
|
||||
def extract_named_value(value: Any) -> str | None:
|
||||
"""Extract a readable string out of Jira's typed objects."""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return value.get("name") or value.get("value")
|
||||
return getattr(value, "name", None)
|
||||
|
||||
|
||||
def extract_user(value: Any) -> tuple[str | None, str | None]:
|
||||
"""Return display name + email tuple for a Jira user blob."""
|
||||
if value is None:
|
||||
return None, None
|
||||
if isinstance(value, dict):
|
||||
return value.get("displayName"), value.get("emailAddress")
|
||||
|
||||
display = getattr(value, "displayName", None)
|
||||
email = getattr(value, "emailAddress", None)
|
||||
return display, email
|
||||
|
||||
|
||||
def extract_text_from_adf(adf: Any) -> str:
|
||||
"""Flatten Atlassian Document Format (ADF) structures to text."""
|
||||
texts: list[str] = []
|
||||
|
||||
def _walk(node: Any) -> None:
|
||||
if node is None:
|
||||
return
|
||||
if isinstance(node, dict):
|
||||
node_type = node.get("type")
|
||||
if node_type == "text":
|
||||
texts.append(node.get("text", ""))
|
||||
for child in node.get("content", []):
|
||||
_walk(child)
|
||||
elif isinstance(node, list):
|
||||
for child in node:
|
||||
_walk(child)
|
||||
|
||||
_walk(adf)
|
||||
return "\n".join(part for part in texts if part)
|
||||
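# Illustrative example (assumed input): the ADF fragment
#   {"type": "doc", "content": [
#       {"type": "paragraph", "content": [{"type": "text", "text": "Hello"}]},
#       {"type": "paragraph", "content": [{"type": "text", "text": "World"}]}]}
# flattens to "Hello\nWorld".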
|
||||
|
||||
def extract_body_text(value: Any) -> str:
|
||||
"""Normalize Jira description/comments (raw/adf/str) into plain text."""
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
return extract_text_from_adf(value).strip()
|
||||
return str(value).strip()
|
||||
|
||||
|
||||
def format_comments(
|
||||
comment_block: Any,
|
||||
*,
|
||||
blacklist: Collection[str],
|
||||
) -> str:
|
||||
"""Convert Jira comments into a markdown-ish bullet list."""
|
||||
if not isinstance(comment_block, dict):
|
||||
return ""
|
||||
|
||||
comments = comment_block.get("comments") or []
|
||||
lines: list[str] = []
|
||||
normalized_blacklist = {email.lower() for email in blacklist if email}
|
||||
|
||||
for comment in comments:
|
||||
author = comment.get("author") or {}
|
||||
author_email = (author.get("emailAddress") or "").lower()
|
||||
if author_email and author_email in normalized_blacklist:
|
||||
continue
|
||||
|
||||
author_name = author.get("displayName") or author.get("name") or author_email or "Unknown"
|
||||
created = parse_jira_datetime(comment.get("created"))
|
||||
created_str = created.isoformat() if created else "Unknown time"
|
||||
body = extract_body_text(comment.get("body"))
|
||||
if not body:
|
||||
continue
|
||||
|
||||
lines.append(f"- {author_name} ({created_str}):\n{body}")
|
||||
|
||||
return "\n\n".join(lines)
|
||||
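# Illustrative output (assumed data): a comment by "Jane Doe" created at
# 2024-01-15T10:30:00+00:00 with body "Looks good to me" renders as:
#   - Jane Doe (2024-01-15T10:30:00+00:00):
#   Looks good to me
# Comments from blacklisted author emails, or with empty bodies, are dropped.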
|
||||
|
||||
def format_attachments(attachments: Any) -> str:
|
||||
"""List Jira attachments as bullet points."""
|
||||
if not isinstance(attachments, list):
|
||||
return ""
|
||||
|
||||
attachment_lines: list[str] = []
|
||||
for attachment in attachments:
|
||||
filename = attachment.get("filename")
|
||||
if not filename:
|
||||
continue
|
||||
size = attachment.get("size")
|
||||
size_text = f" ({size} bytes)" if isinstance(size, int) else ""
|
||||
content_url = attachment.get("content") or ""
|
||||
url_suffix = f" -> {content_url}" if content_url else ""
|
||||
attachment_lines.append(f"- {filename}{size_text}{url_suffix}")
|
||||
|
||||
return "\n".join(attachment_lines)
|
||||
|
||||
|
||||
def should_skip_issue(issue: Issue, labels_to_skip: set[str]) -> bool:
|
||||
"""Return True if the issue contains any label from the skip list."""
|
||||
if not labels_to_skip:
|
||||
return False
|
||||
|
||||
fields = getattr(issue, "raw", {}).get("fields", {})
|
||||
labels: Iterable[str] = fields.get("labels") or []
|
||||
for label in labels:
|
||||
if (label or "").lower() in labels_to_skip:
|
||||
return True
|
||||
return False
|
||||
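# Illustrative example (assumed data): with labels_to_skip={"wontfix"}, an issue
# whose raw fields contain labels ["WontFix", "backend"] is skipped, because the
# issue-side labels are lowercased before comparison.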
@ -1,112 +0,0 @@
|
||||
"""Jira connector"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from jira import JIRA
|
||||
|
||||
from common.data_source.config import INDEX_BATCH_SIZE
|
||||
from common.data_source.exceptions import (
|
||||
ConnectorValidationError,
|
||||
InsufficientPermissionsError,
|
||||
UnexpectedValidationError, ConnectorMissingCredentialError
|
||||
)
|
||||
from common.data_source.interfaces import (
|
||||
CheckpointedConnectorWithPermSync,
|
||||
SecondsSinceUnixEpoch,
|
||||
SlimConnectorWithPermSync
|
||||
)
|
||||
from common.data_source.models import (
|
||||
ConnectorCheckpoint
|
||||
)
|
||||
|
||||
|
||||
class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
|
||||
"""Jira connector for accessing Jira issues and projects"""
|
||||
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.jira_client: JIRA | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Load Jira credentials"""
|
||||
try:
|
||||
url = credentials.get("url")
|
||||
username = credentials.get("username")
|
||||
password = credentials.get("password")
|
||||
token = credentials.get("token")
|
||||
|
||||
if not url:
|
||||
raise ConnectorMissingCredentialError("Jira URL is required")
|
||||
|
||||
if token:
|
||||
# API token authentication
|
||||
self.jira_client = JIRA(server=url, token_auth=token)
|
||||
elif username and password:
|
||||
# Basic authentication
|
||||
self.jira_client = JIRA(server=url, basic_auth=(username, password))
|
||||
else:
|
||||
raise ConnectorMissingCredentialError("Jira credentials are incomplete")
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise ConnectorMissingCredentialError(f"Jira: {e}")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""Validate Jira connector settings"""
|
||||
if not self.jira_client:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
try:
|
||||
# Test connection by getting server info
|
||||
self.jira_client.server_info()
|
||||
except Exception as e:
|
||||
if "401" in str(e) or "403" in str(e):
|
||||
raise InsufficientPermissionsError("Invalid credentials or insufficient permissions")
|
||||
elif "404" in str(e):
|
||||
raise ConnectorValidationError("Jira instance not found")
|
||||
else:
|
||||
raise UnexpectedValidationError(f"Jira validation error: {e}")
|
||||
|
||||
def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any:
|
||||
"""Poll Jira for recent issues"""
|
||||
# Simplified implementation - in production this would handle actual polling
|
||||
return []
|
||||
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: ConnectorCheckpoint,
|
||||
) -> Any:
|
||||
"""Load documents from checkpoint with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
|
||||
def build_dummy_checkpoint(self) -> ConnectorCheckpoint:
|
||||
"""Build dummy checkpoint"""
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def validate_checkpoint_json(self, checkpoint_json: str) -> ConnectorCheckpoint:
|
||||
"""Validate checkpoint JSON"""
|
||||
# Simplified implementation
|
||||
return ConnectorCheckpoint()
|
||||
|
||||
def retrieve_all_slim_docs_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> Any:
|
||||
"""Retrieve all simplified documents with permission sync"""
|
||||
# Simplified implementation
|
||||
return []
|
||||
@ -48,17 +48,35 @@ from common.data_source.exceptions import RateLimitTriedTooManyTimesError
|
||||
from common.data_source.interfaces import CT, CheckpointedConnector, CheckpointOutputWrapper, ConfluenceUser, LoadFunction, OnyxExtensionType, SecondsSinceUnixEpoch, TokenResponse
|
||||
from common.data_source.models import BasicExpertInfo, Document
|
||||
|
||||
_TZ_SUFFIX_PATTERN = re.compile(r"([+-])([\d:]+)$")
|
||||
|
||||
|
||||
def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_string = datetime_string.strip()
|
||||
|
||||
match_jira_format = _TZ_SUFFIX_PATTERN.search(datetime_string)
|
||||
if match_jira_format:
|
||||
sign, tz_field = match_jira_format.groups()
|
||||
digits = tz_field.replace(":", "")
|
||||
|
||||
if digits.isdigit() and 1 <= len(digits) <= 4:
|
||||
if len(digits) >= 3:
|
||||
hours = digits[:-2].rjust(2, "0")
|
||||
minutes = digits[-2:]
|
||||
else:
|
||||
hours = digits.rjust(2, "0")
|
||||
minutes = "00"
|
||||
|
||||
normalized = f"{sign}{hours}:{minutes}"
|
||||
datetime_string = f"{datetime_string[: match_jira_format.start()]}{normalized}"
|
||||
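# Illustrative normalization (assumed inputs): a "+0530" suffix has digits
# "0530" -> hours "05", minutes "30" -> "+05:30"; a bare "+9" becomes "+09:00";
# an already-valid "+05:30" round-trips unchanged.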
|
||||
# Handle the case where the datetime string ends with 'Z' (Zulu time)
|
||||
if datetime_string.endswith('Z'):
|
||||
datetime_string = datetime_string[:-1] + '+00:00'
|
||||
if datetime_string.endswith("Z"):
|
||||
datetime_string = datetime_string[:-1] + "+00:00"
|
||||
|
||||
# Handle timezone format "+0000" -> "+00:00"
|
||||
if datetime_string.endswith('+0000'):
|
||||
datetime_string = datetime_string[:-5] + '+00:00'
|
||||
if datetime_string.endswith("+0000"):
|
||||
datetime_string = datetime_string[:-5] + "+00:00"
|
||||
|
||||
datetime_object = datetime.fromisoformat(datetime_string)
|
||||
|
||||
@ -293,6 +311,13 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
|
||||
aws_secret_access_key=credentials["secret_access_key"],
|
||||
region_name=credentials["region"],
|
||||
)
|
||||
elif bucket_type == BlobType.S3_COMPATIBLE:
|
||||
return boto3.client(
|
||||
"s3",
|
||||
endpoint_url=credentials["endpoint_url"],
|
||||
aws_access_key_id=credentials["aws_access_key_id"],
|
||||
aws_secret_access_key=credentials["aws_secret_access_key"],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported bucket type: {bucket_type}")
|
||||
@ -480,7 +505,7 @@ def get_file_ext(file_name: str) -> str:
|
||||
|
||||
|
||||
def is_accepted_file_ext(file_ext: str, extension_type: OnyxExtensionType) -> bool:
|
||||
image_extensions = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
|
||||
image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp"}
|
||||
text_extensions = {".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql"}
|
||||
document_extensions = {".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html"}
|
||||
|
||||
@ -902,6 +927,18 @@ def load_all_docs_from_checkpoint_connector(
|
||||
)
|
||||
|
||||
|
||||
_ATLASSIAN_CLOUD_DOMAINS = (".atlassian.net", ".jira.com", ".jira-dev.com")
|
||||
|
||||
|
||||
def is_atlassian_cloud_url(url: str) -> bool:
|
||||
try:
|
||||
host = urlparse(url).hostname or ""
|
||||
except ValueError:
|
||||
return False
|
||||
host = host.lower()
|
||||
return any(host.endswith(domain) for domain in _ATLASSIAN_CLOUD_DOMAINS)
|
||||
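# Illustrative examples (assumed URLs): "https://acme.atlassian.net" and
# "https://acme.jira.com" are treated as Atlassian Cloud, while
# "https://jira.internal.example.com" is not.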
|
||||
|
||||
def get_cloudId(base_url: str) -> str:
|
||||
tenant_info_url = urljoin(base_url, "/_edge/tenant_info")
|
||||
response = requests.get(tenant_info_url, timeout=10)
|
||||
|
||||
@ -80,4 +80,4 @@ def log_exception(e, *args):
|
||||
raise Exception(a.text)
|
||||
else:
|
||||
logging.error(str(a))
|
||||
raise e
|
||||
raise e
|
||||
|
||||
@ -21,7 +21,7 @@ import weakref
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from concurrent.futures import TimeoutError as FuturesTimeoutError
|
||||
from string import Template
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@ -30,12 +30,15 @@ from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
|
||||
from rag.llm.chat_model import ToolCallSession
|
||||
|
||||
MCPTaskType = Literal["list_tools", "tool_call"]
|
||||
MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]]
|
||||
|
||||
|
||||
class ToolCallSession(Protocol):
|
||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
class MCPToolCallSession(ToolCallSession):
|
||||
_ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()
|
||||
|
||||
@ -106,7 +109,8 @@ class MCPToolCallSession(ToolCallSession):
|
||||
await self._process_mcp_tasks(None, msg)
|
||||
|
||||
else:
|
||||
await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
|
||||
await self._process_mcp_tasks(None,
|
||||
f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}")
|
||||
|
||||
async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None:
|
||||
while not self._close:
|
||||
@ -164,7 +168,8 @@ class MCPToolCallSession(ToolCallSession):
|
||||
raise
|
||||
|
||||
async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str:
|
||||
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, timeout=timeout)
|
||||
result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments,
|
||||
timeout=timeout)
|
||||
|
||||
if result.isError:
|
||||
return f"MCP server error: {result.content}"
|
||||
@ -283,7 +288,8 @@ def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) ->
|
||||
except Exception:
|
||||
logging.exception("Exception during MCP session cleanup thread management")
|
||||
|
||||
logging.info(f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
|
||||
logging.info(
|
||||
f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.")
|
||||
|
||||
|
||||
def shutdown_all_mcp_sessions():
|
||||
@ -298,7 +304,7 @@ def shutdown_all_mcp_sessions():
|
||||
logging.info("All MCPToolCallSession instances have been closed.")
|
||||
|
||||
|
||||
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]:
|
||||
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
|
||||
if isinstance(mcp_tool, dict):
|
||||
return {
|
||||
"type": "function",
|
||||
@ -4839,6 +4839,639 @@
|
||||
"is_tools": false
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "JieKou.AI",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK",
|
||||
"status": "1",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "Sao10K/L3-8B-Stheno-v3.2",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "baichuan/baichuan-m2-32b",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "baidu/ernie-4.5-300b-a47b-paddle",
|
||||
"tags": "LLM,CHAT,123K",
|
||||
"max_tokens": 123000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "baidu/ernie-4.5-vl-424b-a47b",
|
||||
"tags": "LLM,CHAT,123K",
|
||||
"max_tokens": 123000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-3-5-haiku-20241022",
|
||||
"tags": "LLM,CHAT,200K",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-3-5-sonnet-20241022",
|
||||
"tags": "LLM,CHAT,200K",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-3-7-sonnet-20250219",
|
||||
"tags": "LLM,CHAT,200K",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-3-haiku-20240307",
|
||||
"tags": "LLM,CHAT,200K",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-haiku-4-5-20251001",
|
||||
"tags": "LLM,CHAT,20K,IMAGE2TEXT",
|
||||
"max_tokens": 20000,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-opus-4-1-20250805",
|
||||
"tags": "LLM,CHAT,200K",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-opus-4-20250514",
|
||||
"tags": "LLM,CHAT,200K",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-sonnet-4-20250514",
|
||||
"tags": "LLM,CHAT,200K",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "claude-sonnet-4-5-20250929",
|
||||
"tags": "LLM,CHAT,200K,IMAGE2TEXT",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek/deepseek-r1-0528",
|
||||
"tags": "LLM,CHAT,163K",
|
||||
"max_tokens": 163840,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek/deepseek-v3-0324",
|
||||
"tags": "LLM,CHAT,163K",
|
||||
"max_tokens": 163840,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek/deepseek-v3.1",
|
||||
"tags": "LLM,CHAT,163K",
|
||||
"max_tokens": 163840,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "doubao-1-5-pro-32k-250115",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "doubao-1.5-pro-32k-character-250715",
|
||||
"tags": "LLM,CHAT,200K",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemini-2.0-flash-20250609",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemini-2.0-flash-lite",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemini-2.5-flash",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemini-2.5-flash-lite",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemini-2.5-flash-lite-preview-06-17",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemini-2.5-flash-lite-preview-09-2025",
|
||||
"tags": "LLM,CHAT,1M,IMAGE2TEXT",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemini-2.5-flash-preview-05-20",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemini-2.5-pro",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gemini-2.5-pro-preview-06-05",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "google/gemma-3-12b-it",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "google/gemma-3-27b-it",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-4.1",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1047576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-4.1-mini",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1047576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-4.1-nano",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1047576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-4o",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-4o-mini",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5",
|
||||
"tags": "LLM,CHAT,400K",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5-chat-latest",
|
||||
"tags": "LLM,CHAT,400K",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5-codex",
|
||||
"tags": "LLM,CHAT,400K,IMAGE2TEXT",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5-mini",
|
||||
"tags": "LLM,CHAT,400K",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5-nano",
|
||||
"tags": "LLM,CHAT,400K",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5-pro",
|
||||
"tags": "LLM,CHAT,400K,IMAGE2TEXT",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5.1",
|
||||
"tags": "LLM,CHAT,400K",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5.1-chat-latest",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gpt-5.1-codex",
|
||||
"tags": "LLM,CHAT,400K",
|
||||
"max_tokens": 400000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "grok-3",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "grok-3-mini",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "grok-4-0709",
|
||||
"tags": "LLM,CHAT,256K",
|
||||
"max_tokens": 256000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "grok-4-fast-non-reasoning",
|
||||
"tags": "LLM,CHAT,2M,IMAGE2TEXT",
|
||||
"max_tokens": 2000000,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "grok-4-fast-reasoning",
|
||||
"tags": "LLM,CHAT,2M,IMAGE2TEXT",
|
||||
"max_tokens": 2000000,
|
||||
"model_type": "image2text",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "grok-code-fast-1",
|
||||
"tags": "LLM,CHAT,256K",
|
||||
"max_tokens": 256000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "gryphe/mythomax-l2-13b",
|
||||
"tags": "LLM,CHAT,4K",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.1-8b-instruct",
|
||||
"tags": "LLM,CHAT,16K",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.2-3b-instruct",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.3-70b-instruct",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-4-maverick-17b-128e-instruct-fp8",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1048576,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "minimaxai/minimax-m1-80k",
|
||||
"tags": "LLM,CHAT,1M",
|
||||
"max_tokens": 1000000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "mistralai/mistral-7b-instruct",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "mistralai/mistral-nemo",
|
||||
"tags": "LLM,CHAT,60K",
|
||||
"max_tokens": 60288,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "moonshotai/kimi-k2-0905",
|
||||
"tags": "LLM,CHAT,262K",
|
||||
"max_tokens": 262144,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "moonshotai/kimi-k2-instruct",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "o1",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "o1-mini",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "o3",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "o3-mini",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/gpt-oss-120b",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "openai/gpt-oss-20b",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen-2.5-72b-instruct",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen-mt-plus",
|
||||
"tags": "LLM,CHAT,4K",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen2.5-7b-instruct",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen2.5-vl-72b-instruct",
|
||||
"tags": "LLM,CHAT,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-235b-a22b-fp8",
|
||||
"tags": "LLM,CHAT,40K",
|
||||
"max_tokens": 40960,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-235b-a22b-instruct-2507",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-235b-a22b-thinking-2507",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-30b-a3b-fp8",
|
||||
"tags": "LLM,CHAT,40K",
|
||||
"max_tokens": 40960,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-32b-fp8",
|
||||
"tags": "LLM,CHAT,40K",
|
||||
"max_tokens": 40960,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-8b-fp8",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-coder-480b-a35b-instruct",
|
||||
"tags": "LLM,CHAT,262K",
|
||||
"max_tokens": 262144,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-next-80b-a3b-instruct",
|
||||
"tags": "LLM,CHAT,65K",
|
||||
"max_tokens": 65536,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-next-80b-a3b-thinking",
|
||||
"tags": "LLM,CHAT,65K",
|
||||
"max_tokens": 65536,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "sao10k/l3-70b-euryale-v2.1",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "sao10k/l3-8b-lunaris",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "sao10k/l31-70b-euryale-v2.2",
|
||||
"tags": "LLM,CHAT,8K",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "thudm/glm-4.1v-9b-thinking",
|
||||
"tags": "LLM,CHAT,65K",
|
||||
"max_tokens": 65536,
|
||||
"model_type": "chat",
|
||||
"is_tools": false
|
||||
},
|
||||
{
|
||||
"llm_name": "zai-org/glm-4.5",
|
||||
"tags": "LLM,CHAT,131K",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "zai-org/glm-4.5v",
|
||||
"tags": "LLM,CHAT,65K",
|
||||
"max_tokens": 65536,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "baai/bge-m3",
|
||||
"tags": "TEXT EMBEDDING,8K",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-embedding-0.6b",
|
||||
"tags": "TEXT EMBEDDING,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-embedding-8b",
|
||||
"tags": "TEXT EMBEDDING,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "baai/bge-reranker-v2-m3",
|
||||
"tags": "RE-RANK,8K",
|
||||
"max_tokens": 8000,
|
||||
"model_type": "reranker"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen3-reranker-8b",
|
||||
"tags": "RE-RANK,32K",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "reranker"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@ -186,9 +186,6 @@ class DoclingParser(RAGFlowPdfParser):
|
||||
yield (DoclingContentType.EQUATION.value, text, bbox)
|
||||
|
||||
def _transfer_to_sections(self, doc) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Consistent with MinerUParser: returns [(section_text, line_tag), ...]
|
||||
"""
|
||||
sections: list[tuple[str, str]] = []
|
||||
for typ, payload, bbox in self._iter_doc_items(doc):
|
||||
if typ == DoclingContentType.TEXT.value:
|
||||
|
||||
@ -34,6 +34,7 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
||||
if isinstance(figure_data[1], Image.Image)
|
||||
]
|
||||
|
||||
|
||||
def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
|
||||
try:
|
||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
@ -50,7 +51,8 @@ def vision_figure_parser_docx_wrapper(sections,tbls,callback=None,**kwargs):
|
||||
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
|
||||
return tbls
|
||||
|
||||
def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):
|
||||
|
||||
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
|
||||
try:
|
||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
||||
@ -72,6 +74,7 @@ def vision_figure_parser_pdf_wrapper(tbls,callback=None,**kwargs):
|
||||
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
|
||||
return tbls
|
||||
|
||||
|
||||
shared_executor = ThreadPoolExecutor(max_workers=10)
|
||||
|
||||
|
||||
|
||||
@ -434,7 +434,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
if not section.strip():
|
||||
section = "FAILED TO PARSE TABLE"
|
||||
case MinerUContentType.IMAGE:
|
||||
section = "".join(output["image_caption"]) + "\n" + "".join(output["image_footnote"])
|
||||
section = "".join(output.get("image_caption", [])) + "\n" + "".join(output.get("image_footnote", []))
|
||||
case MinerUContentType.EQUATION:
|
||||
section = output["text"]
|
||||
case MinerUContentType.CODE:
|
||||
|
||||
@ -117,7 +117,6 @@ def load_model(model_dir, nm, device_id: int | None = None):
|
||||
providers=['CUDAExecutionProvider'],
|
||||
provider_options=[cuda_provider_options]
|
||||
)
|
||||
run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(provider_device_id))
|
||||
logging.info(f"load_model {model_file_path} uses GPU (device {provider_device_id}, gpu_mem_limit={cuda_provider_options['gpu_mem_limit']}, arena_strategy={arena_strategy})")
|
||||
else:
|
||||
sess = ort.InferenceSession(
|
||||
|
||||
@ -71,7 +71,7 @@ for arg in "$@"; do
|
||||
ENABLE_TASKEXECUTOR=0
|
||||
shift
|
||||
;;
|
||||
--disable-datasyn)
|
||||
--disable-datasync)
|
||||
ENABLE_DATASYNC=0
|
||||
shift
|
||||
;;
|
||||
|
||||
@ -12,6 +12,10 @@ The RAGFlow Admin UI is a web-based interface that provides comprehensive system
|
||||
|
||||
To access the RAGFlow admin UI, append `/admin` to the web UI's address, e.g. `http://[RAGFLOW_WEB_UI_ADDR]/admin`, replacing `[RAGFLOW_WEB_UI_ADDR]` with the actual RAGFlow web UI address.
|
||||
|
||||
### Default Credentials
|
||||
| Username | Password |
|
||||
|----------|----------|
|
||||
| `admin@ragflow.io` | `admin` |
|
||||
|
||||
## Admin UI Overview
|
||||
|
||||
|
||||
8
docs/guides/dataset/add_data_source/_category_.json
Normal file
@ -0,0 +1,8 @@
|
||||
{
|
||||
"label": "Add data source",
|
||||
"position": 18,
|
||||
"link": {
|
||||
"type": "generated-index",
|
||||
"description": "Add various data sources"
|
||||
}
|
||||
}
|
||||
137
docs/guides/dataset/add_data_source/add_google_drive.md
Normal file
@ -0,0 +1,137 @@
|
||||
---
|
||||
sidebar_position: 3
|
||||
slug: /add_google_drive
|
||||
---
|
||||
|
||||
# Add Google Drive
|
||||
|
||||
## 1. Create a Google Cloud Project
|
||||
|
||||
You can either create a dedicated project for RAGFlow or use an existing external Google Cloud project.
|
||||
|
||||
**Steps:**
|
||||
1. Open the project creation page\
|
||||
`https://console.cloud.google.com/projectcreate`
|
||||

|
||||
2. Select **External** as the Audience
|
||||

|
||||
3. Click **Create**
|
||||

|
||||
|
||||
------------------------------------------------------------------------
|
||||
|
||||
## 2. Configure OAuth Consent Screen
|
||||
|
||||
1. Go to **APIs & Services → OAuth consent screen**
|
||||
2. Ensure **User Type = External**
|
||||

|
||||
3. Add your test users under **Test Users** by entering email addresses
|
||||

|
||||

|
||||
|
||||
------------------------------------------------------------------------
|
||||
|
||||
## 3. Create OAuth Client Credentials
|
||||
|
||||
1. Navigate to:\
|
||||
`https://console.cloud.google.com/auth/clients`
|
||||
2. Create a **Web Application**
|
||||

|
||||
3. Enter a name for the client
|
||||
4. Add the following **Authorized Redirect URIs**:
|
||||
|
||||
```
|
||||
http://localhost:9380/v1/connector/google-drive/oauth/web/callback
|
||||
```
|
||||
|
||||
### If using Docker deployment:
|
||||
|
||||
**Authorized JavaScript origin:**
|
||||
```
|
||||
http://localhost:80
|
||||
```
|
||||
|
||||

|
||||
### If running from source:
|
||||
**Authorized JavaScript origin:**
|
||||
```
|
||||
http://localhost:9222
|
||||
```
|
||||
|
||||

|
||||
5. After saving, click **Download JSON**. This file will later be
|
||||
uploaded into RAGFlow.
|
||||
|
||||

|
||||
|
||||
------------------------------------------------------------------------
|
||||
|
||||
## 4. Add Scopes
|
||||
|
||||
1. Open **Data Access → Add or remove scopes**
|
||||
|
||||
2. Paste and add the following entries:
|
||||
|
||||
```
|
||||
https://www.googleapis.com/auth/drive.readonly
|
||||
https://www.googleapis.com/auth/drive.metadata.readonly
|
||||
https://www.googleapis.com/auth/admin.directory.group.readonly
|
||||
https://www.googleapis.com/auth/admin.directory.user.readonly
|
||||
```
|
||||
|
||||

|
||||
3. Click **Update**, then save the changes
|
||||
|
||||

|
||||

|
||||
|
||||
------------------------------------------------------------------------
|
||||
|
||||
## 5. Enable Required APIs
|
||||
Navigate to the Google API Library:\
|
||||
`https://console.cloud.google.com/apis/library`
|
||||

|
||||
|
||||
Enable the following APIs:
|
||||
|
||||
- Google Drive API
|
||||
- Admin SDK API
|
||||
- Google Sheets API
|
||||
- Google Docs API
|
||||
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
------------------------------------------------------------------------
|
||||
|
||||
## 6. Add Google Drive As a Data Source in RAGFlow
|
||||
|
||||
1. Go to **Data Sources** inside RAGFlow
|
||||
2. Select **Google Drive**
|
||||
3. Upload the previously downloaded JSON credentials
|
||||

|
||||
4. Enter the link of the shared Google Drive folder (from https://drive.google.com/drive), for example:
|
||||

|
||||
|
||||
5. Click **Authorize with Google**
|
||||
A browser window will appear.
|
||||

|
||||
   Click:
   - **Continue**
   - **Select All → Continue**
   - Authorization should succeed
   - Select **OK** to add the data source
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
{
|
||||
"label": "Best practices",
|
||||
"position": 11,
|
||||
"position": 19,
|
||||
"link": {
|
||||
"type": "generated-index",
|
||||
"description": "Best practices on configuring a dataset."
|
||||
|
||||
@ -64,7 +64,10 @@ The Admin CLI and Admin Service form a client-server architectural suite for RAG
|
||||
|
||||
- -p: RAGFlow admin server port
|
||||
|
||||
## Default administrative account
|
||||
|
||||
- Username: admin@ragflow.io
|
||||
- Password: admin
|
||||
|
||||
## Supported Commands
|
||||
|
||||
|
||||
@ -974,6 +974,237 @@ Failure:
|
||||
|
||||
---
|
||||
|
||||
### Construct knowledge graph
|
||||
|
||||
**POST** `/api/v1/datasets/{dataset_id}/run_graphrag`
|
||||
|
||||
Constructs a knowledge graph from a specified dataset.
|
||||
|
||||
#### Request
|
||||
|
||||
- Method: POST
|
||||
- URL: `/api/v1/datasets/{dataset_id}/run_graphrag`
|
||||
- Headers:
|
||||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||
|
||||
##### Request example
|
||||
|
||||
```bash
|
||||
curl --request POST \
|
||||
--url http://{address}/api/v1/datasets/{dataset_id}/run_graphrag \
|
||||
--header 'Authorization: Bearer <YOUR_API_KEY>'
|
||||
```
|
||||
|
||||
##### Request parameters
|
||||
|
||||
- `dataset_id`: (*Path parameter*)
|
||||
The ID of the target dataset.
|
||||
|
||||
#### Response
|
||||
|
||||
Success:
|
||||
|
||||
```json
|
||||
{
|
||||
"code":0,
|
||||
"data":{
|
||||
"graphrag_task_id":"e498de54bfbb11f0ba028f704583b57b"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Failure:
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "Invalid Dataset ID"
|
||||
}
|
||||
```
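
The same call can be scripted outside of `curl`. Below is a minimal Python sketch, not a definitive implementation: it assumes the `requests` package and uses placeholder values for the server address (the API port `9380` shown earlier in the docs is an assumption), API key, and dataset ID. It starts knowledge graph construction and captures the returned `graphrag_task_id`.

```python
import requests

BASE_URL = "http://127.0.0.1:9380"   # assumption: default RAGFlow API address
API_KEY = "<YOUR_API_KEY>"           # placeholder
DATASET_ID = "<DATASET_ID>"          # placeholder

# Kick off knowledge graph construction for the dataset.
resp = requests.post(
    f"{BASE_URL}/api/v1/datasets/{DATASET_ID}/run_graphrag",
    headers={"Authorization": f"Bearer {API_KEY}"},
    timeout=30,
)
body = resp.json()
if body.get("code") == 0:
    task_id = body["data"]["graphrag_task_id"]
    print("Knowledge graph task started:", task_id)
else:
    print("Request failed:", body.get("message"))
```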
|
||||
|
||||
---
|
||||
|
||||
### Get knowledge graph construction status
|
||||
|
||||
**GET** `/api/v1/datasets/{dataset_id}/trace_graphrag`
|
||||
|
||||
Retrieves the knowledge graph construction status for a specified dataset.
|
||||
|
||||
#### Request
|
||||
|
||||
- Method: GET
|
||||
- URL: `/api/v1/datasets/{dataset_id}/trace_graphrag`
|
||||
- Headers:
|
||||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||
|
||||
##### Request example
|
||||
|
||||
```bash
|
||||
curl --request GET \
|
||||
--url http://{address}/api/v1/datasets/{dataset_id}/trace_graphrag \
|
||||
--header 'Authorization: Bearer <YOUR_API_KEY>'
|
||||
```
|
||||
|
||||
##### Request parameters
|
||||
|
||||
- `dataset_id`: (*Path parameter*)
|
||||
The ID of the target dataset.
|
||||
|
||||
#### Response
|
||||
|
||||
Success:
|
||||
|
||||
```json
|
||||
{
|
||||
"code":0,
|
||||
"data":{
|
||||
"begin_at":"Wed, 12 Nov 2025 19:36:56 GMT",
|
||||
"chunk_ids":"",
|
||||
"create_date":"Wed, 12 Nov 2025 19:36:56 GMT",
|
||||
"create_time":1762947416350,
|
||||
"digest":"39e43572e3dcd84f",
|
||||
"doc_id":"44661c10bde211f0bc93c164a47ffc40",
|
||||
"from_page":100000000,
|
||||
"id":"e498de54bfbb11f0ba028f704583b57b",
|
||||
"priority":0,
|
||||
"process_duration":2.45419,
|
||||
"progress":1.0,
|
||||
"progress_msg":"19:36:56 created task graphrag\n19:36:57 Task has been received.\n19:36:58 [GraphRAG] doc:083661febe2411f0bc79456921e5745f has no available chunks, skip generation.\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 start (chunks=1, timeout=10000000000s)\n19:36:58 Graph already contains 44661c10bde211f0bc93c164a47ffc40\n19:36:58 [GraphRAG] build_subgraph doc:44661c10bde211f0bc93c164a47ffc40 empty\n19:36:58 [GraphRAG] kb:33137ed0bde211f0bc93c164a47ffc40 no subgraphs generated successfully, end.\n19:36:58 Knowledge Graph done (0.72s)","retry_count":1,
|
||||
"task_type":"graphrag",
|
||||
"to_page":100000000,
|
||||
"update_date":"Wed, 12 Nov 2025 19:36:58 GMT",
|
||||
"update_time":1762947418454
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Failure:
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "Invalid Dataset ID"
|
||||
}
|
||||
```
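
To wait for construction to finish, the status endpoint can be polled until `data.progress` reaches `1.0`. The sketch below is illustrative only and reuses the placeholder values and assumptions from the previous example.

```python
import time

import requests

BASE_URL = "http://127.0.0.1:9380"   # assumption: default RAGFlow API address
API_KEY = "<YOUR_API_KEY>"           # placeholder
DATASET_ID = "<DATASET_ID>"          # placeholder

# Poll the construction status until the task reports completion.
while True:
    resp = requests.get(
        f"{BASE_URL}/api/v1/datasets/{DATASET_ID}/trace_graphrag",
        headers={"Authorization": f"Bearer {API_KEY}"},
        timeout=30,
    )
    data = resp.json().get("data") or {}
    progress = data.get("progress", 0.0)
    print(f"progress={progress:.2f}")
    if progress >= 1.0:
        print(data.get("progress_msg", ""))
        break
    time.sleep(5)   # wait a few seconds between polls
```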
|
||||
|
||||
---
|
||||
|
||||
### Construct RAPTOR
|
||||
|
||||
**POST** `/api/v1/datasets/{dataset_id}/run_raptor`
|
||||
|
||||
Constructs a RAPTOR tree from a specified dataset.
|
||||
|
||||
#### Request
|
||||
|
||||
- Method: POST
|
||||
- URL: `/api/v1/datasets/{dataset_id}/run_raptor`
|
||||
- Headers:
|
||||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||
|
||||
##### Request example
|
||||
|
||||
```bash
|
||||
curl --request POST \
|
||||
--url http://{address}/api/v1/datasets/{dataset_id}/run_raptor \
|
||||
--header 'Authorization: Bearer <YOUR_API_KEY>'
|
||||
```
|
||||
|
||||
##### Request parameters
|
||||
|
||||
- `dataset_id`: (*Path parameter*)
|
||||
The ID of the target dataset.
|
||||
|
||||
#### Response
|
||||
|
||||
Success:
|
||||
|
||||
```json
|
||||
{
|
||||
"code":0,
|
||||
"data":{
|
||||
"raptor_task_id":"50d3c31cbfbd11f0ba028f704583b57b"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Failure:
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "Invalid Dataset ID"
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Get RAPTOR construction status
|
||||
|
||||
**GET** `/api/v1/datasets/{dataset_id}/trace_raptor`
|
||||
|
||||
Retrieves the RAPTOR construction status for a specified dataset.
|
||||
|
||||
#### Request
|
||||
|
||||
- Method: GET
|
||||
- URL: `/api/v1/datasets/{dataset_id}/trace_raptor`
|
||||
- Headers:
|
||||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||
|
||||
##### Request example
|
||||
|
||||
```bash
|
||||
curl --request GET \
|
||||
--url http://{address}/api/v1/datasets/{dataset_id}/trace_raptor \
|
||||
--header 'Authorization: Bearer <YOUR_API_KEY>'
|
||||
```
|
||||
|
||||
##### Request parameters
|
||||
|
||||
- `dataset_id`: (*Path parameter*)
|
||||
The ID of the target dataset.
|
||||
|
||||
#### Response
|
||||
|
||||
Success:
|
||||
|
||||
```json
|
||||
{
|
||||
"code":0,
|
||||
"data":{
|
||||
"begin_at":"Wed, 12 Nov 2025 19:47:07 GMT",
|
||||
"chunk_ids":"",
|
||||
"create_date":"Wed, 12 Nov 2025 19:47:07 GMT",
|
||||
"create_time":1762948027427,
|
||||
"digest":"8b279a6248cb8fc6",
|
||||
"doc_id":"44661c10bde211f0bc93c164a47ffc40",
|
||||
"from_page":100000000,
|
||||
"id":"50d3c31cbfbd11f0ba028f704583b57b",
|
||||
"priority":0,
|
||||
"process_duration":0.948244,
|
||||
"progress":1.0,
|
||||
"progress_msg":"19:47:07 created task raptor\n19:47:07 Task has been received.\n19:47:07 Processing...\n19:47:07 Processing...\n19:47:07 Indexing done (0.01s).\n19:47:07 Task done (0.29s)",
|
||||
"retry_count":1,
|
||||
"task_type":"raptor",
|
||||
"to_page":100000000,
|
||||
"update_date":"Wed, 12 Nov 2025 19:47:07 GMT",
|
||||
"update_time":1762948027948
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Failure:
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "Invalid Dataset ID"
|
||||
}
|
||||
```
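
The RAPTOR endpoints follow the same pattern as the knowledge graph endpoints: `run_raptor` returns a `raptor_task_id`, and `trace_raptor` reports progress. A compact sketch under the same assumptions and placeholders as the examples above:

```python
import requests

BASE_URL = "http://127.0.0.1:9380"                     # assumption: default RAGFlow API address
headers = {"Authorization": "Bearer <YOUR_API_KEY>"}   # placeholder key
dataset = "<DATASET_ID>"                               # placeholder

# Start RAPTOR construction and note the returned task ID.
run = requests.post(f"{BASE_URL}/api/v1/datasets/{dataset}/run_raptor", headers=headers, timeout=30).json()
print("raptor_task_id:", run.get("data", {}).get("raptor_task_id"))

# Check construction status; data.progress reaches 1.0 once the task is done.
trace = requests.get(f"{BASE_URL}/api/v1/datasets/{dataset}/trace_raptor", headers=headers, timeout=30).json()
print("progress:", trace.get("data", {}).get("progress"))
```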
|
||||
|
||||
---
|
||||
|
||||
## FILE MANAGEMENT WITHIN DATASET
|
||||
|
||||
---
|
||||
|
||||
@ -67,6 +67,7 @@ A complete list of models supported by RAGFlow, which will continue to expand.
|
||||
| 302.AI | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| CometAPI | :heavy_check_mark: | :heavy_check_mark: | | | | |
|
||||
| DeerAPI | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | | :heavy_check_mark: |
|
||||
| Jiekou.AI | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | |
|
||||
|
||||
```mdx-code-block
|
||||
</APITable>
|
||||
|
||||
@ -693,7 +693,7 @@ Released on August 26, 2024.
|
||||
- Incorporates monitoring for the task executor.
|
||||
- Introduces Agent tools **GitHub**, **DeepL**, **BaiduFanyi**, **QWeather**, and **GoogleScholar**.
|
||||
- Supports chunking of EML files.
|
||||
- Supports more LLMs or model services: **GPT-4o-mini**, **PerfXCloud**, **TogetherAI**, **Upstage**, **Novita AI**, **01.AI**, **SiliconFlow**, **PPIO**, **XunFei Spark**, **Baidu Yiyan**, and **Tencent Hunyuan**.
|
||||
- Supports more LLMs or model services: **GPT-4o-mini**, **PerfXCloud**, **TogetherAI**, **Upstage**, **Novita AI**, **01.AI**, **SiliconFlow**, **PPIO**, **XunFei Spark**, **Jiekou.AI**, **Baidu Yiyan**, and **Tencent Hunyuan**.
|
||||
|
||||
## v0.9.0
|
||||
|
||||
|
||||
@ -114,7 +114,7 @@ class Extractor:
|
||||
async def extract_all(doc_id, chunks, max_concurrency=MAX_CONCURRENT_PROCESS_AND_EXTRACT_CHUNK, task_id=""):
|
||||
out_results = []
|
||||
error_count = 0
|
||||
max_errors = 3
|
||||
max_errors = int(os.environ.get("GRAPHRAG_MAX_ERRORS", 3))
|
||||
|
||||
limiter = trio.Semaphore(max_concurrency)
|
||||
|
||||
|
||||
@ -69,7 +69,7 @@ class KGSearch(Dealer):
|
||||
def _ent_info_from_(self, es_res, sim_thr=0.3):
|
||||
res = {}
|
||||
flds = ["content_with_weight", "_score", "entity_kwd", "rank_flt", "n_hop_with_weight"]
|
||||
es_res = self.dataStore.getFields(es_res, flds)
|
||||
es_res = self.dataStore.get_fields(es_res, flds)
|
||||
for _, ent in es_res.items():
|
||||
for f in flds:
|
||||
if f in ent and ent[f] is None:
|
||||
@ -88,7 +88,7 @@ class KGSearch(Dealer):
|
||||
|
||||
def _relation_info_from_(self, es_res, sim_thr=0.3):
|
||||
res = {}
|
||||
es_res = self.dataStore.getFields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
|
||||
es_res = self.dataStore.get_fields(es_res, ["content_with_weight", "_score", "from_entity_kwd", "to_entity_kwd",
|
||||
"weight_int"])
|
||||
for _, ent in es_res.items():
|
||||
if get_float(ent["_score"]) < sim_thr:
|
||||
@ -300,7 +300,7 @@ class KGSearch(Dealer):
|
||||
fltr["entities_kwd"] = entities
|
||||
comm_res = self.dataStore.search(fields, [], fltr, [],
|
||||
OrderByExpr(), 0, topn, idxnms, kb_ids)
|
||||
comm_res_fields = self.dataStore.getFields(comm_res, fields)
|
||||
comm_res_fields = self.dataStore.get_fields(comm_res, fields)
|
||||
txts = []
|
||||
for ii, (_, row) in enumerate(comm_res_fields.items()):
|
||||
obj = json.loads(row["content_with_weight"])
|
||||
|
||||
@ -382,7 +382,7 @@ async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
|
||||
fields2 = settings.docStoreConn.getFields(res, fields)
|
||||
fields2 = settings.docStoreConn.get_fields(res, fields)
|
||||
graph_doc_ids = set()
|
||||
for chunk_id in fields2.keys():
|
||||
graph_doc_ids = set(fields2[chunk_id]["source_id"])
|
||||
@ -591,8 +591,8 @@ async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
|
||||
es_res = await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.search(flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]}, [], OrderByExpr(), i, bs, search.index_name(tenant_id), [kb_id])
|
||||
)
|
||||
# tot = settings.docStoreConn.getTotal(es_res)
|
||||
es_res = settings.docStoreConn.getFields(es_res, flds)
|
||||
# tot = settings.docStoreConn.get_total(es_res)
|
||||
es_res = settings.docStoreConn.get_fields(es_res, flds)
|
||||
|
||||
if len(es_res) == 0:
|
||||
break
|
||||
|
||||
@ -145,6 +145,7 @@ dependencies = [
|
||||
"markdownify>=1.2.0",
|
||||
"captcha>=0.7.1",
|
||||
"pip>=25.2",
|
||||
"pypandoc>=1.16",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@ -49,6 +49,7 @@ class SupportedLiteLLMProvider(StrEnum):
|
||||
Lingyi_AI = "01.AI"
|
||||
GiteeAI = "GiteeAI"
|
||||
AI_302 = "302.AI"
|
||||
JiekouAI = "Jiekou.AI"
|
||||
|
||||
|
||||
FACTORY_DEFAULT_BASE_URL = {
|
||||
@ -69,6 +70,7 @@ FACTORY_DEFAULT_BASE_URL = {
|
||||
SupportedLiteLLMProvider.GiteeAI: "https://ai.gitee.com/v1/",
|
||||
SupportedLiteLLMProvider.AI_302: "https://api.302.ai/v1",
|
||||
SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
|
||||
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
|
||||
}
|
||||
|
||||
|
||||
@ -99,6 +101,7 @@ LITELLM_PROVIDER_PREFIX = {
|
||||
SupportedLiteLLMProvider.Lingyi_AI: "openai/",
|
||||
SupportedLiteLLMProvider.GiteeAI: "openai/",
|
||||
SupportedLiteLLMProvider.AI_302: "openai/",
|
||||
SupportedLiteLLMProvider.JiekouAI: "openai/",
|
||||
}
|
||||
|
||||
ChatModel = globals().get("ChatModel", {})
|
||||
|
||||
@ -22,7 +22,6 @@ import re
|
||||
import time
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import urljoin
|
||||
|
||||
import json_repair
|
||||
@ -65,10 +64,6 @@ LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小
|
||||
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
|
||||
|
||||
|
||||
class ToolCallSession(Protocol):
|
||||
def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ...
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, key, model_name, base_url, **kwargs):
|
||||
timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
|
||||
@ -1402,6 +1397,7 @@ class LiteLLMBase(ABC):
|
||||
"01.AI",
|
||||
"GiteeAI",
|
||||
"302.AI",
|
||||
"Jiekou.AI",
|
||||
]
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import re
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
@ -32,7 +33,6 @@ from rag.nlp import is_english
|
||||
from rag.prompts.generator import vision_llm_describe_prompt
|
||||
from common.token_utils import num_tokens_from_string, total_token_count_from_response
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
def __init__(self, **kwargs):
|
||||
# Configure retry parameters
|
||||
@ -208,6 +208,7 @@ class GptV4(Base):
|
||||
model=self.model_name,
|
||||
messages=self.prompt(b64),
|
||||
extra_body=self.extra_body,
|
||||
unused = None,
|
||||
)
|
||||
return res.choices[0].message.content.strip(), total_token_count_from_response(res)
|
||||
|
||||
@ -324,6 +325,122 @@ class Zhipu4V(GptV4):
|
||||
Base.__init__(self, **kwargs)
|
||||
|
||||
|
||||
def _clean_conf(self, gen_conf):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
gen_conf = self._clean_conf_plealty(gen_conf)
|
||||
return gen_conf
|
||||
|
||||
|
||||
def _clean_conf_plealty(self, gen_conf):
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
del gen_conf["frequency_penalty"]
|
||||
return gen_conf
|
||||
|
||||
|
||||
def _request(self, msg, stream, gen_conf={}):
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
json={
|
||||
"model": self.model_name,
|
||||
"messages": msg,
|
||||
"stream": stream,
|
||||
**gen_conf
|
||||
},
|
||||
headers= {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
)
|
||||
return response.json()
|
||||
|
||||
|
||||
def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
|
||||
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=False, **gen_conf)
|
||||
content = response.choices[0].message.content.strip()
|
||||
|
||||
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
|
||||
return cleaned, total_token_count_from_response(response)
|
||||
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf, images=None, **kwargs):
|
||||
from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN
|
||||
from rag.nlp import is_chinese
|
||||
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
try:
|
||||
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=True, **gen_conf)
|
||||
for resp in response:
|
||||
if not resp.choices[0].delta.content:
|
||||
continue
|
||||
delta = resp.choices[0].delta.content
|
||||
ans = delta
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
tk_count = total_token_count_from_response(resp)
|
||||
if resp.choices[0].finish_reason == "stop":
|
||||
tk_count = total_token_count_from_response(resp)
|
||||
yield ans
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield tk_count
|
||||
|
||||
|
||||
def describe(self, image):
|
||||
return self.describe_with_prompt(image)
|
||||
|
||||
|
||||
def describe_with_prompt(self, image, prompt=None):
|
||||
b64 = self.image2base64(image)
|
||||
if prompt is None:
|
||||
prompt = "Describe this image."
|
||||
|
||||
# Chat messages
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": { "url": b64 }
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
resp = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
stream=False
|
||||
)
|
||||
|
||||
content = resp.choices[0].message.content.strip()
|
||||
cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip()
|
||||
|
||||
return cleaned, num_tokens_from_string(cleaned)
|
||||
|
||||
|
||||
class StepFunCV(GptV4):
|
||||
_FACTORY_NAME = "StepFun"
|
||||
|
||||
|
||||
@ -931,3 +931,12 @@ class DeerAPIEmbed(OpenAIEmbed):
|
||||
if not base_url:
|
||||
base_url = "https://api.deerapi.com/v1"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class JiekouAIEmbed(OpenAIEmbed):
|
||||
_FACTORY_NAME = "Jiekou.AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1/embeddings"):
|
||||
if not base_url:
|
||||
base_url = "https://api.jiekou.ai/openai/v1/embeddings"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
@ -489,3 +489,12 @@ class Ai302Rerank(Base):
|
||||
if not base_url:
|
||||
base_url = "https://api.302.ai/v1/rerank"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class JiekouAIRerank(JinaRerank):
|
||||
_FACTORY_NAME = "Jiekou.AI"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.jiekou.ai/openai/v1/rerank"):
|
||||
if not base_url:
|
||||
base_url = "https://api.jiekou.ai/openai/v1/rerank"
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
@ -155,13 +155,13 @@ def qbullets_category(sections):
|
||||
if re.match(pro, sec) and not not_bullet(sec):
|
||||
hits[i] += 1
|
||||
break
|
||||
maxium = 0
|
||||
maximum = 0
|
||||
res = -1
|
||||
for i, h in enumerate(hits):
|
||||
if h <= maxium:
|
||||
if h <= maximum:
|
||||
continue
|
||||
res = i
|
||||
maxium = h
|
||||
maximum = h
|
||||
return res, QUESTION_PATTERN[res]
|
||||
|
||||
|
||||
@ -222,13 +222,13 @@ def bullets_category(sections):
|
||||
if re.match(p, sec) and not not_bullet(sec):
|
||||
hits[i] += 1
|
||||
break
|
||||
maxium = 0
|
||||
maximum = 0
|
||||
res = -1
|
||||
for i, h in enumerate(hits):
|
||||
if h <= maxium:
|
||||
if h <= maximum:
|
||||
continue
|
||||
res = i
|
||||
maxium = h
|
||||
maximum = h
|
||||
return res
|
||||
|
||||
|
||||
@ -482,7 +482,7 @@ def tree_merge(bull, sections, depth):
|
||||
root = Node(level=0, depth=target_level, texts=[])
|
||||
root.build_tree(lines)
|
||||
|
||||
return [("\n").join(element) for element in root.get_tree() if element]
|
||||
return [element for element in root.get_tree() if element]
|
||||
|
||||
def hierarchical_merge(bull, sections, depth):
|
||||
|
||||
@ -723,47 +723,40 @@ def naive_merge_docx(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
if not sections:
|
||||
return [], []
|
||||
|
||||
cks = [""]
|
||||
images = [None]
|
||||
tk_nums = [0]
|
||||
cks = []
|
||||
images = []
|
||||
tk_nums = []
|
||||
|
||||
def add_chunk(t, image, pos=""):
|
||||
nonlocal cks, tk_nums, delimiter
|
||||
nonlocal cks, images, tk_nums
|
||||
tnum = num_tokens_from_string(t)
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
|
||||
if t.find(pos) < 0:
|
||||
|
||||
if not cks or tk_nums[-1] > chunk_token_num:
|
||||
# new chunk
|
||||
if pos and t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
images.append(image)
|
||||
tk_nums.append(tnum)
|
||||
else:
|
||||
if cks[-1].find(pos) < 0:
|
||||
# add to last chunk
|
||||
if pos and cks[-1].find(pos) < 0:
|
||||
t += pos
|
||||
cks[-1] += t
|
||||
images[-1] = concat_img(images[-1], image)
|
||||
tk_nums[-1] += tnum
|
||||
|
||||
dels = get_delimiters(delimiter)
|
||||
line = ""
|
||||
for sec, image in sections:
|
||||
if not image:
|
||||
line += sec + "\n"
|
||||
continue
|
||||
split_sec = re.split(r"(%s)" % dels, line + sec)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image,"")
|
||||
line = ""
|
||||
pattern = r"(%s)" % dels
|
||||
|
||||
if line:
|
||||
split_sec = re.split(r"(%s)" % dels, line)
|
||||
for sec, image in sections:
|
||||
split_sec = re.split(pattern, sec)
|
||||
for sub_sec in split_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
if not sub_sec or re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
add_chunk("\n"+sub_sec, image,"")
|
||||
add_chunk("\n" + sub_sec, image, "")
|
||||
|
||||
return cks, images
|
||||
|
||||
|
||||
@ -38,11 +38,11 @@ class FulltextQueryer:
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def subSpecialChar(line):
|
||||
def sub_special_char(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def isChinese(line):
|
||||
def is_chinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
@ -92,7 +92,7 @@ class FulltextQueryer:
|
||||
otxt = txt
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
|
||||
if not self.isChinese(txt):
|
||||
if not self.is_chinese(txt):
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
keywords = [t for t in tks if t]
|
||||
@ -163,7 +163,7 @@ class FulltextQueryer:
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
|
||||
sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
if len(keywords) < 32:
|
||||
@ -171,7 +171,7 @@ class FulltextQueryer:
|
||||
keywords.extend(sm)
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
@ -180,7 +180,7 @@ class FulltextQueryer:
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
tk = FulltextQueryer.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
@ -198,7 +198,7 @@ class FulltextQueryer:
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.subSpecialChar(s))
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
@ -217,17 +217,17 @@ class FulltextQueryer:
|
||||
return None, keywords
|
||||
|
||||
def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
|
||||
from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
import numpy as np
|
||||
|
||||
sims = CosineSimilarity([avec], bvecs)
|
||||
sims = cosine_similarity([avec], bvecs)
|
||||
tksim = self.token_similarity(atks, btkss)
|
||||
if np.sum(sims[0]) == 0:
|
||||
return np.array(tksim), tksim, sims[0]
|
||||
return np.array(sims[0]) * vtweight + np.array(tksim) * tkweight, tksim, sims[0]
|
||||
|
||||
def token_similarity(self, atks, btkss):
|
||||
def toDict(tks):
|
||||
def to_dict(tks):
|
||||
if isinstance(tks, str):
|
||||
tks = tks.split()
|
||||
d = defaultdict(int)
|
||||
@ -236,8 +236,8 @@ class FulltextQueryer:
|
||||
d[t] += c
|
||||
return d
|
||||
|
||||
atks = toDict(atks)
|
||||
btkss = [toDict(tks) for tks in btkss]
|
||||
atks = to_dict(atks)
|
||||
btkss = [to_dict(tks) for tks in btkss]
|
||||
return [self.similarity(atks, btks) for btks in btkss]
|
||||
|
||||
def similarity(self, qtwt, dtwt):
|
||||
@ -262,10 +262,10 @@ class FulltextQueryer:
|
||||
keywords = [f'"{k.strip()}"' for k in keywords]
|
||||
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
|
||||
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
tk = FulltextQueryer.subSpecialChar(tk)
|
||||
tk = FulltextQueryer.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
|
||||
@ -35,7 +35,7 @@ class RagTokenizer:
|
||||
def rkey_(self, line):
|
||||
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
||||
|
||||
def loadDict_(self, fnm):
|
||||
def _load_dict(self, fnm):
|
||||
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
||||
try:
|
||||
of = open(fnm, "r", encoding='utf-8')
|
||||
@ -85,18 +85,18 @@ class RagTokenizer:
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
|
||||
# load data from dict file and save to trie file
|
||||
self.loadDict_(self.DIR_ + ".txt")
|
||||
self._load_dict(self.DIR_ + ".txt")
|
||||
|
||||
def loadUserDict(self, fnm):
|
||||
def load_user_dict(self, fnm):
|
||||
try:
|
||||
self.trie_ = datrie.Trie.load(fnm + ".trie")
|
||||
return
|
||||
except Exception:
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
self.loadDict_(fnm)
|
||||
self._load_dict(fnm)
|
||||
|
||||
def addUserDict(self, fnm):
|
||||
self.loadDict_(fnm)
|
||||
def add_user_dict(self, fnm):
|
||||
self._load_dict(fnm)
|
||||
|
||||
def _strQ2B(self, ustring):
|
||||
"""Convert full-width characters to half-width characters"""
|
||||
@ -221,7 +221,7 @@ class RagTokenizer:
|
||||
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
|
||||
return tks, B / len(tks) + L + F
|
||||
|
||||
def sortTks_(self, tkslist):
|
||||
def _sort_tokens(self, tkslist):
|
||||
res = []
|
||||
for tfts in tkslist:
|
||||
tks, s = self.score_(tfts)
|
||||
@ -246,7 +246,7 @@ class RagTokenizer:
|
||||
|
||||
return " ".join(res)
|
||||
|
||||
def maxForward_(self, line):
|
||||
def _max_forward(self, line):
|
||||
res = []
|
||||
s = 0
|
||||
while s < len(line):
|
||||
@ -270,7 +270,7 @@ class RagTokenizer:
|
||||
|
||||
return self.score_(res)
|
||||
|
||||
def maxBackward_(self, line):
|
||||
def _max_backward(self, line):
|
||||
res = []
|
||||
s = len(line) - 1
|
||||
while s >= 0:
|
||||
@ -336,8 +336,8 @@ class RagTokenizer:
|
||||
continue
|
||||
|
||||
# use maxforward for the first time
|
||||
tks, s = self.maxForward_(L)
|
||||
tks1, s1 = self.maxBackward_(L)
|
||||
tks, s = self._max_forward(L)
|
||||
tks1, s1 = self._max_backward(L)
|
||||
if self.DEBUG:
|
||||
logging.debug("[FW] {} {}".format(tks, s))
|
||||
logging.debug("[BW] {} {}".format(tks1, s1))
|
||||
@ -369,7 +369,7 @@ class RagTokenizer:
|
||||
# backward tokens from_i to i are different from forward tokens from _j to j.
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
|
||||
|
||||
same = 1
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
@ -385,7 +385,7 @@ class RagTokenizer:
|
||||
assert "".join(tks1[_i:]) == "".join(tks[_j:])
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
|
||||
res.append(" ".join(self.sortTks_(tkslist)[0][0]))
|
||||
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
|
||||
|
||||
res = " ".join(res)
|
||||
logging.debug("[TKS] {}".format(self.merge_(res)))
|
||||
@ -413,7 +413,7 @@ class RagTokenizer:
|
||||
if len(tkslist) < 2:
|
||||
res.append(tk)
|
||||
continue
|
||||
stk = self.sortTks_(tkslist)[1][0]
|
||||
stk = self._sort_tokens(tkslist)[1][0]
|
||||
if len(stk) == len(tk):
|
||||
stk = tk
|
||||
else:
|
||||
@ -447,14 +447,13 @@ def is_number(s):
|
||||
|
||||
|
||||
def is_alphabet(s):
|
||||
if (s >= u'\u0041' and s <= u'\u005a') or (
|
||||
s >= u'\u0061' and s <= u'\u007a'):
|
||||
if (u'\u0041' <= s <= u'\u005a') or (u'\u0061' <= s <= u'\u007a'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def naiveQie(txt):
|
||||
def naive_qie(txt):
|
||||
tks = []
|
||||
for t in txt.split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]
|
||||
@ -469,14 +468,14 @@ tokenize = tokenizer.tokenize
|
||||
fine_grained_tokenize = tokenizer.fine_grained_tokenize
|
||||
tag = tokenizer.tag
|
||||
freq = tokenizer.freq
|
||||
loadUserDict = tokenizer.loadUserDict
|
||||
addUserDict = tokenizer.addUserDict
|
||||
load_user_dict = tokenizer.load_user_dict
|
||||
add_user_dict = tokenizer.add_user_dict
|
||||
tradi2simp = tokenizer._tradi2simp
|
||||
strQ2B = tokenizer._strQ2B
|
||||
|
||||
if __name__ == '__main__':
|
||||
tknzr = RagTokenizer(debug=True)
|
||||
# huqie.addUserDict("/tmp/tmp.new.tks.dict")
|
||||
# huqie.add_user_dict("/tmp/tmp.new.tks.dict")
|
||||
tks = tknzr.tokenize(
|
||||
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈")
|
||||
logging.info(tknzr.fine_grained_tokenize(tks))
|
||||
@ -506,7 +505,7 @@ if __name__ == '__main__':
|
||||
if len(sys.argv) < 2:
|
||||
sys.exit()
|
||||
tknzr.DEBUG = False
|
||||
tknzr.loadUserDict(sys.argv[1])
|
||||
tknzr.load_user_dict(sys.argv[1])
|
||||
of = open(sys.argv[2], "r")
|
||||
while True:
|
||||
line = of.readline()
|
||||
|
||||
@ -102,7 +102,7 @@ class Dealer:
|
||||
orderBy.asc("top_int")
|
||||
orderBy.desc("create_timestamp_flt")
|
||||
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
else:
|
||||
highlightFields = ["content_ltks", "title_tks"]
|
||||
@ -115,7 +115,7 @@ class Dealer:
|
||||
matchExprs = [matchText]
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
else:
|
||||
matchDense = self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1))
|
||||
@ -127,20 +127,20 @@ class Dealer:
|
||||
|
||||
res = self.dataStore.search(src, highlightFields, filters, matchExprs, orderBy, offset, limit,
|
||||
idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search TOTAL: {}".format(total))
|
||||
|
||||
# If result is empty, try again with lower min_match
|
||||
if total == 0:
|
||||
if filters.get("doc_id"):
|
||||
res = self.dataStore.search(src, [], filters, [], orderBy, offset, limit, idx_names, kb_ids)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
else:
|
||||
matchText, _ = self.qryr.question(qst, min_match=0.1)
|
||||
matchDense.extra_options["similarity"] = 0.17
|
||||
res = self.dataStore.search(src, highlightFields, filters, [matchText, matchDense, fusionExpr],
|
||||
orderBy, offset, limit, idx_names, kb_ids, rank_feature=rank_feature)
|
||||
total = self.dataStore.getTotal(res)
|
||||
total = self.dataStore.get_total(res)
|
||||
logging.debug("Dealer.search 2 TOTAL: {}".format(total))
|
||||
|
||||
for k in keywords:
|
||||
@ -153,17 +153,17 @@ class Dealer:
|
||||
kwds.add(kk)
|
||||
|
||||
logging.debug(f"TOTAL: {total}")
|
||||
ids = self.dataStore.getChunkIds(res)
|
||||
ids = self.dataStore.get_chunk_ids(res)
|
||||
keywords = list(kwds)
|
||||
highlight = self.dataStore.getHighlight(res, keywords, "content_with_weight")
|
||||
aggs = self.dataStore.getAggregation(res, "docnm_kwd")
|
||||
highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
|
||||
aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
|
||||
return self.SearchResult(
|
||||
total=total,
|
||||
ids=ids,
|
||||
query_vector=q_vec,
|
||||
aggregation=aggs,
|
||||
highlight=highlight,
|
||||
field=self.dataStore.getFields(res, src + ["_score"]),
|
||||
field=self.dataStore.get_fields(res, src + ["_score"]),
|
||||
keywords=keywords
|
||||
)
|
||||
|
||||
@ -347,7 +347,7 @@ class Dealer:
|
||||
## For rank feature(tag_fea) scores.
|
||||
rank_fea = self._rank_feature_scores(rank_feature, sres)
|
||||
|
||||
return tkweight * (np.array(tksim)+rank_fea) + vtweight * vtsim, tksim, vtsim
|
||||
return tkweight * np.array(tksim) + vtweight * vtsim + rank_fea, tksim, vtsim
|
||||
|
||||
def hybrid_similarity(self, ans_embd, ins_embd, ans, inst):
|
||||
return self.qryr.hybrid_similarity(ans_embd,
|
||||
@ -488,7 +488,7 @@ class Dealer:
|
||||
for p in range(offset, max_count, bs):
|
||||
es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, bs, index_name(tenant_id),
|
||||
kb_ids)
|
||||
dict_chunks = self.dataStore.getFields(es_res, fields)
|
||||
dict_chunks = self.dataStore.get_fields(es_res, fields)
|
||||
for id, doc in dict_chunks.items():
|
||||
doc["id"] = id
|
||||
if dict_chunks:
|
||||
@ -501,11 +501,11 @@ class Dealer:
|
||||
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
|
||||
return []
|
||||
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
|
||||
return self.dataStore.getAggregation(res, "tag_kwd")
|
||||
return self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
|
||||
def all_tags_in_portion(self, tenant_id: str, kb_ids: list[str], S=1000):
|
||||
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
|
||||
res = self.dataStore.getAggregation(res, "tag_kwd")
|
||||
res = self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
total = np.sum([c for _, c in res])
|
||||
return {t: (c + 1) / (total + S) for t, c in res}
|
||||
|
||||
@ -513,7 +513,7 @@ class Dealer:
|
||||
idx_nm = index_name(tenant_id)
|
||||
match_txt = self.qryr.paragraph(doc["title_tks"] + " " + doc["content_ltks"], doc.get("important_kwd", []), keywords_topn)
|
||||
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nm, kb_ids, ["tag_kwd"])
|
||||
aggs = self.dataStore.getAggregation(res, "tag_kwd")
|
||||
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
if not aggs:
|
||||
return False
|
||||
cnt = np.sum([c for _, c in aggs])
|
||||
@ -529,7 +529,7 @@ class Dealer:
|
||||
idx_nms = [index_name(tid) for tid in tenant_ids]
|
||||
match_txt, _ = self.qryr.question(question, min_match=0.0)
|
||||
res = self.dataStore.search([], [], {}, [match_txt], OrderByExpr(), 0, 0, idx_nms, kb_ids, ["tag_kwd"])
|
||||
aggs = self.dataStore.getAggregation(res, "tag_kwd")
|
||||
aggs = self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
if not aggs:
|
||||
return {}
|
||||
cnt = np.sum([c for _, c in aggs])
|
||||
@ -552,7 +552,7 @@ class Dealer:
|
||||
es_res = self.dataStore.search(["content_with_weight"], [], {"doc_id": doc_id, "toc_kwd": "toc"}, [], OrderByExpr(), 0, 128, idx_nms,
|
||||
kb_ids)
|
||||
toc = []
|
||||
dict_chunks = self.dataStore.getFields(es_res, ["content_with_weight"])
|
||||
dict_chunks = self.dataStore.get_fields(es_res, ["content_with_weight"])
|
||||
for _, doc in dict_chunks.items():
|
||||
try:
|
||||
toc.extend(json.loads(doc["content_with_weight"]))
|
||||
|
||||
@ -113,20 +113,20 @@ class Dealer:
|
||||
res.append(tk)
|
||||
return res
|
||||
|
||||
def tokenMerge(self, tks):
|
||||
def oneTerm(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
def token_merge(self, tks):
|
||||
def one_term(t): return len(t) == 1 or re.match(r"[0-9a-z]{1,2}$", t)
|
||||
|
||||
res, i = [], 0
|
||||
while i < len(tks):
|
||||
j = i
|
||||
if i == 0 and oneTerm(tks[i]) and len(
|
||||
if i == 0 and one_term(tks[i]) and len(
|
||||
tks) > 1 and (len(tks[i + 1]) > 1 and not re.match(r"[0-9a-zA-Z]", tks[i + 1])): # 多 工位
|
||||
res.append(" ".join(tks[0:2]))
|
||||
i = 2
|
||||
continue
|
||||
|
||||
while j < len(
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and oneTerm(tks[j]):
|
||||
tks) and tks[j] and tks[j] not in self.stop_words and one_term(tks[j]):
|
||||
j += 1
|
||||
if j - i > 1:
|
||||
if j - i < 5:
|
||||
@ -232,7 +232,7 @@ class Dealer:
|
||||
tw = list(zip(tks, wts))
|
||||
else:
|
||||
for tk in tks:
|
||||
tt = self.tokenMerge(self.pretoken(tk, True))
|
||||
tt = self.token_merge(self.pretoken(tk, True))
|
||||
idf1 = np.array([idf(freq(t), 10000000) for t in tt])
|
||||
idf2 = np.array([idf(df(t), 1000000000) for t in tt])
|
||||
wts = (0.3 * idf1 + 0.7 * idf2) * \
|
||||
|
||||
154
rag/raptor.py
@ -15,27 +15,35 @@
|
||||
#
|
||||
import logging
import re
import umap

import numpy as np
from sklearn.mixture import GaussianMixture
import trio
import umap
from sklearn.mixture import GaussianMixture

from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
from common.exceptions import TaskCanceledException
from common.token_utils import truncate
from graphrag.utils import (
    get_llm_cache,
    chat_limiter,
    get_embed_cache,
    get_llm_cache,
    set_embed_cache,
    set_llm_cache,
    chat_limiter,
)
from common.token_utils import truncate


class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
    def __init__(
        self, max_cluster, llm_model, embd_model, prompt, max_token=512, threshold=0.1
        self,
        max_cluster,
        llm_model,
        embd_model,
        prompt,
        max_token=512,
        threshold=0.1,
        max_errors=3,
    ):
        self._max_cluster = max_cluster
        self._llm_model = llm_model
@@ -43,31 +51,35 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        self._threshold = threshold
        self._prompt = prompt
        self._max_token = max_token
        self._max_errors = max(1, max_errors)
        self._error_count = 0

    @timeout(60*20)
    @timeout(60 * 20)
    async def _chat(self, system, history, gen_conf):
        response = await trio.to_thread.run_sync(
            lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf)
        )
        cached = await trio.to_thread.run_sync(lambda: get_llm_cache(self._llm_model.llm_name, system, history, gen_conf))
        if cached:
            return cached

        if response:
            return response
        response = await trio.to_thread.run_sync(
            lambda: self._llm_model.chat(system, history, gen_conf)
        )
        response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
        if response.find("**ERROR**") >= 0:
            raise Exception(response)
        await trio.to_thread.run_sync(
            lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf)
        )
        return response
        last_exc = None
        for attempt in range(3):
            try:
                response = await trio.to_thread.run_sync(lambda: self._llm_model.chat(system, history, gen_conf))
                response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
                if response.find("**ERROR**") >= 0:
                    raise Exception(response)
                await trio.to_thread.run_sync(lambda: set_llm_cache(self._llm_model.llm_name, system, response, history, gen_conf))
                return response
            except Exception as exc:
                last_exc = exc
                logging.warning("RAPTOR LLM call failed on attempt %d/3: %s", attempt + 1, exc)
                if attempt < 2:
                    await trio.sleep(1 + attempt)

        raise last_exc if last_exc else Exception("LLM chat failed without exception")

    @timeout(20)
    async def _embedding_encode(self, txt):
        response = await trio.to_thread.run_sync(
            lambda: get_embed_cache(self._embd_model.llm_name, txt)
        )
        response = await trio.to_thread.run_sync(lambda: get_embed_cache(self._embd_model.llm_name, txt))
        if response is not None:
            return response
        embds, _ = await trio.to_thread.run_sync(lambda: self._embd_model.encode([txt]))
@@ -82,7 +94,6 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        n_clusters = np.arange(1, max_clusters)
        bics = []
        for n in n_clusters:

            if task_id:
                if has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled during get optimal clusters.")
@@ -101,7 +112,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
        layers = [(0, len(chunks))]
        start, end = 0, len(chunks)

        @timeout(60*20)
        @timeout(60 * 20)
        async def summarize(ck_idx: list[int]):
            nonlocal chunks

@@ -111,47 +122,50 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
                    raise TaskCanceledException(f"Task {task_id} was cancelled")

            texts = [chunks[i][0] for i in ck_idx]
            len_per_chunk = int(
                (self._llm_model.max_length - self._max_token) / len(texts)
            )
            cluster_content = "\n".join(
                [truncate(t, max(1, len_per_chunk)) for t in texts]
            )
            async with chat_limiter:
            len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
            cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
            try:
                async with chat_limiter:
                    if task_id and has_canceled(task_id):
                        logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
                        raise TaskCanceledException(f"Task {task_id} was cancelled")

                if task_id and has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled before RAPTOR LLM call.")
                    raise TaskCanceledException(f"Task {task_id} was cancelled")
                    cnt = await self._chat(
                        "You're a helpful assistant.",
                        [
                            {
                                "role": "user",
                                "content": self._prompt.format(cluster_content=cluster_content),
                            }
                        ],
                        {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
                    )
                    cnt = re.sub(
                        "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
                        "",
                        cnt,
                    )
                    logging.debug(f"SUM: {cnt}")

                cnt = await self._chat(
                    "You're a helpful assistant.",
                    [
                        {
                            "role": "user",
                            "content": self._prompt.format(
                                cluster_content=cluster_content
                            ),
                        }
                    ],
                    {"max_tokens": max(self._max_token, 512)},  # fix issue: #10235
                )
                cnt = re.sub(
                    "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
                    "",
                    cnt,
                )
                logging.debug(f"SUM: {cnt}")
                    if task_id and has_canceled(task_id):
                        logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
                        raise TaskCanceledException(f"Task {task_id} was cancelled")

                if task_id and has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled before RAPTOR embedding.")
                    raise TaskCanceledException(f"Task {task_id} was cancelled")

                    embds = await self._embedding_encode(cnt)
                    chunks.append((cnt, embds))
                embds = await self._embedding_encode(cnt)
                chunks.append((cnt, embds))
            except TaskCanceledException:
                raise
            except Exception as exc:
                self._error_count += 1
                warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
                logging.warning(warn_msg)
                if callback:
                    callback(msg=warn_msg)
                if self._error_count >= self._max_errors:
                    raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc

        labels = []
        while end - start > 1:

            if task_id:
                if has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled during RAPTOR layer processing.")
@@ -161,11 +175,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
            if len(embeddings) == 2:
                await summarize([start, start + 1])
                if callback:
                    callback(
                        msg="Cluster one layer: {} -> {}".format(
                            end - start, len(chunks) - end
                        )
                    )
                    callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
                labels.extend([0, 0])
                layers.append((end, len(chunks)))
                start = end
@@ -199,17 +209,11 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:

                nursery.start_soon(summarize, ck_idx)

            assert len(chunks) - end == n_clusters, "{} vs. {}".format(
                len(chunks) - end, n_clusters
            )
            assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
            labels.extend(lbls)
            layers.append((end, len(chunks)))
            if callback:
                callback(
                    msg="Cluster one layer: {} -> {}".format(
                        end - start, len(chunks) - end
                    )
                )
                callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
            start = end
            end = len(chunks)

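The reworked `_chat` above replaces the single-shot LLM call with a bounded retry loop. A minimal standalone sketch of that pattern follows; the helper `call_with_retries` and the flaky callable are illustrative stand-ins, not RAGFlow's API.

```python
import logging

import trio


async def call_with_retries(fn, attempts=3):
    """Run a blocking callable in a worker thread, retrying with a growing delay."""
    last_exc = None
    for attempt in range(attempts):
        try:
            return await trio.to_thread.run_sync(fn)
        except Exception as exc:  # illustrative sketch: retry on any failure
            last_exc = exc
            logging.warning("call failed on attempt %d/%d: %s", attempt + 1, attempts, exc)
            if attempt < attempts - 1:
                await trio.sleep(1 + attempt)  # linear backoff: 1s, 2s, ...
    raise last_exc


async def main():
    # A deliberately flaky callable to exercise the retry path.
    state = {"calls": 0}

    def flaky():
        state["calls"] += 1
        if state["calls"] < 3:
            raise RuntimeError("transient failure")
        return "ok"

    print(await call_with_retries(flaky))


if __name__ == "__main__":
    trio.run(main)
```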
@@ -28,7 +28,7 @@ def collect():
    logging.debug(doc_locations)
    if len(doc_locations) == 0:
        time.sleep(1)
        return
        return None
    return doc_locations


@@ -20,33 +20,40 @@


import copy
import faulthandler
import logging
import os
import signal
import sys
import threading
import time
import traceback
from datetime import datetime, timezone
from typing import Any

import trio

from api.db.services.connector_service import ConnectorService, SyncLogsService
from api.db.services.knowledgebase_service import KnowledgebaseService
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from common.data_source import BlobStorageConnector, NotionConnector, DiscordConnector, GoogleDriveConnector
import logging
import os
from datetime import datetime, timezone
import signal
import trio
import faulthandler
from common.constants import FileSource, TaskStatus
from common import settings
from common.versions import get_ragflow_version
from common.config_utils import show_configs
from common.constants import FileSource, TaskStatus
from common.data_source import (
    BlobStorageConnector,
    DiscordConnector,
    GoogleDriveConnector,
    JiraConnector,
    NotionConnector,
)
from common.data_source.config import INDEX_BATCH_SIZE
from common.data_source.confluence_connector import ConfluenceConnector
from common.data_source.interfaces import CheckpointOutputWrapper
from common.data_source.utils import load_all_docs_from_checkpoint_connector
from common.data_source.config import INDEX_BATCH_SIZE
from common.log_utils import init_root_logger
from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc
from common.versions import get_ragflow_version

MAX_CONCURRENT_TASKS = int(os.environ.get('MAX_CONCURRENT_TASKS', "5"))
MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5"))
task_limiter = trio.Semaphore(MAX_CONCURRENT_TASKS)


@@ -72,31 +79,32 @@ class SyncBase:
                min_update = min([doc.doc_updated_at for doc in document_batch])
                max_update = max([doc.doc_updated_at for doc in document_batch])
                next_update = max([next_update, max_update])
                docs = [{
                    "id": doc.id,
                    "connector_id": task["connector_id"],
                    "source": self.SOURCE_NAME,
                    "semantic_identifier": doc.semantic_identifier,
                    "extension": doc.extension,
                    "size_bytes": doc.size_bytes,
                    "doc_updated_at": doc.doc_updated_at,
                    "blob": doc.blob
                } for doc in document_batch]
                docs = [
                    {
                        "id": doc.id,
                        "connector_id": task["connector_id"],
                        "source": self.SOURCE_NAME,
                        "semantic_identifier": doc.semantic_identifier,
                        "extension": doc.extension,
                        "size_bytes": doc.size_bytes,
                        "doc_updated_at": doc.doc_updated_at,
                        "blob": doc.blob,
                    }
                    for doc in document_batch
                ]

                e, kb = KnowledgebaseService.get_by_id(task["kb_id"])
                err, dids = SyncLogsService.duplicate_and_parse(kb, docs, task["tenant_id"], f"{self.SOURCE_NAME}/{task['connector_id']}", task["auto_parse"])
                SyncLogsService.increase_docs(task["id"], min_update, max_update, len(docs), "\n".join(err), len(err))
                doc_num += len(docs)

            logging.info("{} docs synchronized till {}".format(doc_num, next_update))
            prefix = "[Jira] " if self.SOURCE_NAME == FileSource.JIRA else ""
            logging.info(f"{prefix}{doc_num} docs synchronized till {next_update}")
            SyncLogsService.done(task["id"], task["connector_id"])
            task["poll_range_start"] = next_update

        except Exception as ex:
            msg = '\n'.join([
                ''.join(traceback.format_exception_only(None, ex)).strip(),
                ''.join(traceback.format_exception(None, ex, ex.__traceback__)).strip()
            ])
            msg = "\n".join(["".join(traceback.format_exception_only(None, ex)).strip(), "".join(traceback.format_exception(None, ex, ex.__traceback__)).strip()])
            SyncLogsService.update_by_id(task["id"], {"status": TaskStatus.FAIL, "full_exception_trace": msg, "error_msg": str(ex)})

        SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"])
@@ -109,21 +117,16 @@ class S3(SyncBase):
    SOURCE_NAME: str = FileSource.S3

    async def _generate(self, task: dict):
        self.connector = BlobStorageConnector(
            bucket_type=self.conf.get("bucket_type", "s3"),
            bucket_name=self.conf["bucket_name"],
            prefix=self.conf.get("prefix", "")
        )
        self.connector = BlobStorageConnector(bucket_type=self.conf.get("bucket_type", "s3"), bucket_name=self.conf["bucket_name"], prefix=self.conf.get("prefix", ""))
        self.connector.load_credentials(self.conf["credentials"])
        document_batch_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        document_batch_generator = (
            self.connector.load_from_state()
            if task["reindex"] == "1" or not task["poll_range_start"]
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        )

        begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to {}: {}(prefix/{}) {}".format(self.conf.get("bucket_type", "s3"),
                                                              self.conf["bucket_name"],
                                                              self.conf.get("prefix", ""),
                                                              begin_info
                                                              ))
        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to {}: {}(prefix/{}) {}".format(self.conf.get("bucket_type", "s3"), self.conf["bucket_name"], self.conf.get("prefix", ""), begin_info))
        return document_batch_generator


@@ -131,8 +134,8 @@ class Confluence(SyncBase):
    SOURCE_NAME: str = FileSource.CONFLUENCE

    async def _generate(self, task: dict):
        from common.data_source.interfaces import StaticCredentialsProvider
        from common.data_source.config import DocumentSource
        from common.data_source.interfaces import StaticCredentialsProvider

        self.connector = ConfluenceConnector(
            wiki_base=self.conf["wiki_base"],
@@ -141,11 +144,7 @@ class Confluence(SyncBase):
            # page_id=self.conf.get("page_id", ""),
        )

        credentials_provider = StaticCredentialsProvider(
            tenant_id=task["tenant_id"],
            connector_name=DocumentSource.CONFLUENCE,
            credential_json=self.conf["credentials"]
        )
        credentials_provider = StaticCredentialsProvider(tenant_id=task["tenant_id"], connector_name=DocumentSource.CONFLUENCE, credential_json=self.conf["credentials"])
        self.connector.set_credentials_provider(credentials_provider)

        # Determine the time range for synchronization based on reindex or poll_range_start
@@ -174,10 +173,13 @@ class Notion(SyncBase):
    async def _generate(self, task: dict):
        self.connector = NotionConnector(root_page_id=self.conf["root_page_id"])
        self.connector.load_credentials(self.conf["credentials"])
        document_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        document_generator = (
            self.connector.load_from_state()
            if task["reindex"] == "1" or not task["poll_range_start"]
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        )

        begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to Notion: root({}) {}".format(self.conf["root_page_id"], begin_info))
        return document_generator

@@ -194,13 +196,16 @@ class Discord(SyncBase):
            server_ids=server_ids.split(",") if server_ids else [],
            channel_names=channel_names.split(",") if channel_names else [],
            start_date=datetime(1970, 1, 1, tzinfo=timezone.utc).strftime("%Y-%m-%d"),
            batch_size=self.conf.get("batch_size", 1024)
            batch_size=self.conf.get("batch_size", 1024),
        )
        self.connector.load_credentials(self.conf["credentials"])
        document_generator = self.connector.load_from_state() if task["reindex"]=="1" or not task["poll_range_start"] \
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        document_generator = (
            self.connector.load_from_state()
            if task["reindex"] == "1" or not task["poll_range_start"]
            else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp())
        )

        begin_info = "totally" if task["reindex"]=="1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format(task["poll_range_start"])
        logging.info("Connect to Discord: servers({}), channel({}) {}".format(server_ids, channel_names, begin_info))
        return document_generator

@@ -285,7 +290,7 @@ class GoogleDrive(SyncBase):
            admin_email = self.connector.primary_admin_email
        except RuntimeError:
            admin_email = "unknown"
        logging.info("Connect to Google Drive as %s %s", admin_email, begin_info)
        logging.info(f"Connect to Google Drive as {admin_email} {begin_info}")
        return document_batches()

    def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None:
@@ -303,7 +308,93 @@ class Jira(SyncBase):
    SOURCE_NAME: str = FileSource.JIRA

    async def _generate(self, task: dict):
        pass
        connector_kwargs = {
            "jira_base_url": self.conf["base_url"],
            "project_key": self.conf.get("project_key"),
            "jql_query": self.conf.get("jql_query"),
            "batch_size": self.conf.get("batch_size", INDEX_BATCH_SIZE),
            "include_comments": self.conf.get("include_comments", True),
            "include_attachments": self.conf.get("include_attachments", False),
            "labels_to_skip": self._normalize_list(self.conf.get("labels_to_skip")),
            "comment_email_blacklist": self._normalize_list(self.conf.get("comment_email_blacklist")),
            "scoped_token": self.conf.get("scoped_token", False),
            "attachment_size_limit": self.conf.get("attachment_size_limit"),
            "timezone_offset": self.conf.get("timezone_offset"),
        }

        self.connector = JiraConnector(**connector_kwargs)

        credentials = self.conf.get("credentials")
        if not credentials:
            raise ValueError("Jira connector is missing credentials.")

        self.connector.load_credentials(credentials)
        self.connector.validate_connector_settings()

        if task["reindex"] == "1" or not task["poll_range_start"]:
            start_time = 0.0
            begin_info = "totally"
        else:
            start_time = task["poll_range_start"].timestamp()
            begin_info = f"from {task['poll_range_start']}"

        end_time = datetime.now(timezone.utc).timestamp()

        raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE
        try:
            batch_size = int(raw_batch_size)
        except (TypeError, ValueError):
            batch_size = INDEX_BATCH_SIZE
        if batch_size <= 0:
            batch_size = INDEX_BATCH_SIZE

        def document_batches():
            checkpoint = self.connector.build_dummy_checkpoint()
            pending_docs = []
            iterations = 0
            iteration_limit = 100_000

            while checkpoint.has_more:
                wrapper = CheckpointOutputWrapper()
                generator = wrapper(
                    self.connector.load_from_checkpoint(
                        start_time,
                        end_time,
                        checkpoint,
                    )
                )
                for document, failure, next_checkpoint in generator:
                    if failure is not None:
                        logging.warning(
                            f"[Jira] Jira connector failure: {getattr(failure, 'failure_message', failure)}"
                        )
                        continue
                    if document is not None:
                        pending_docs.append(document)
                        if len(pending_docs) >= batch_size:
                            yield pending_docs
                            pending_docs = []
                    if next_checkpoint is not None:
                        checkpoint = next_checkpoint

                iterations += 1
                if iterations > iteration_limit:
                    logging.error(f"[Jira] Task {task.get('id')} exceeded iteration limit ({iteration_limit}).")
                    raise RuntimeError("Too many iterations while loading Jira documents.")

            if pending_docs:
                yield pending_docs

        logging.info(f"[Jira] Connect to Jira {connector_kwargs['jira_base_url']} {begin_info}")
        return document_batches()

    @staticmethod
    def _normalize_list(values: Any) -> list[str] | None:
        if values is None:
            return None
        if isinstance(values, str):
            values = [item.strip() for item in values.split(",")]
        return [str(value).strip() for value in values if value is not None and str(value).strip()]


class SharePoint(SyncBase):
@@ -337,9 +428,10 @@ func_factory = {
    FileSource.JIRA: Jira,
    FileSource.SHAREPOINT: SharePoint,
    FileSource.SLACK: Slack,
    FileSource.TEAMS: Teams
    FileSource.TEAMS: Teams,
}


async def dispatch_tasks():
    async with trio.open_nursery() as nursery:
        while True:
@@ -385,7 +477,7 @@ async def main():
      __/ |
     |___/
    """)
    logging.info(f'RAGFlow version: {get_ragflow_version()}')
    logging.info(f"RAGFlow version: {get_ragflow_version()}")
    show_configs()
    settings.init_settings()
    if sys.platform != "win32":

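The new Jira `_generate` above builds a generator that drains a checkpointed connector and re-groups documents into fixed-size batches. The following sketch isolates that batching idea; `Checkpoint`, `fetch_page`, and `batched_documents` are generic stand-ins, not the connector's real API.

```python
from dataclasses import dataclass
from typing import Iterator


@dataclass
class Checkpoint:
    has_more: bool = True
    cursor: int = 0


def batched_documents(fetch_page, batch_size: int, iteration_limit: int = 100_000) -> Iterator[list]:
    """Keep pulling pages while the checkpoint reports more data, buffer documents,
    and yield them in fixed-size batches, guarding against runaway loops."""
    checkpoint = Checkpoint()
    pending: list = []
    iterations = 0
    while checkpoint.has_more:
        for doc in fetch_page(checkpoint):
            pending.append(doc)
            if len(pending) >= batch_size:
                yield pending
                pending = []
        iterations += 1
        if iterations > iteration_limit:
            raise RuntimeError("Too many iterations while loading documents.")
    if pending:
        yield pending


if __name__ == "__main__":
    def fetch_page(cp: Checkpoint):
        # Fake paginated source: five documents per page, twelve in total.
        docs = list(range(cp.cursor, cp.cursor + 5))
        cp.cursor += 5
        cp.has_more = cp.cursor < 12
        return docs

    for batch in batched_documents(fetch_page, batch_size=4):
        print(batch)
```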
@@ -359,7 +359,7 @@ async def build_chunks(task, progress_callback):
                task_canceled = has_canceled(task["id"])
                if task_canceled:
                    progress_callback(-1, msg="Task has been canceled.")
                    return
                    return None
                if settings.retriever.tag_content(tenant_id, kb_ids, d, all_tags, topn_tags=topn_tags, S=S) and len(d[TAG_FLD]) > 0:
                    examples.append({"content": d["content_with_weight"], TAG_FLD: d[TAG_FLD]})
                else:
@@ -417,6 +417,7 @@ def build_TOC(task, docs, progress_callback):
        d["page_num_int"] = [100000000]
        d["id"] = xxhash.xxh64((d["content_with_weight"] + str(d["doc_id"])).encode("utf-8", "surrogatepass")).hexdigest()
        return d
    return None


def init_kb(row, vector_size: int):
@@ -441,7 +442,7 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
    tk_count = 0
    if len(tts) == len(cnts):
        vts, c = await trio.to_thread.run_sync(lambda: mdl.encode(tts[0: 1]))
        tts = np.concatenate([vts[0] for _ in range(len(tts))], axis=0)
        tts = np.tile(vts[0], (len(cnts), 1))
        tk_count += c

    @timeout(60)
@@ -464,8 +465,10 @@ async def embedding(docs, mdl, parser_config=None, callback=None):
    if not filename_embd_weight:
        filename_embd_weight = 0.1
    title_w = float(filename_embd_weight)
    vects = (title_w * tts + (1 - title_w) *
             cnts) if len(tts) == len(cnts) else cnts
    if tts.ndim == 2 and cnts.ndim == 2 and tts.shape == cnts.shape:
        vects = title_w * tts + (1 - title_w) * cnts
    else:
        vects = cnts

    assert len(vects) == len(docs)
    vector_size = 0
@@ -648,6 +651,8 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si

    res = []
    tk_count = 0
    max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))

    async def generate(chunks, did):
        nonlocal tk_count, res
        raptor = Raptor(
@@ -657,6 +662,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
            raptor_config["prompt"],
            raptor_config["max_token"],
            raptor_config["threshold"],
            max_errors=max_errors,
        )
        original_length = len(chunks)
        chunks = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
@@ -719,7 +725,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
        task_canceled = has_canceled(task_id)
        if task_canceled:
            progress_callback(-1, msg="Task has been canceled.")
            return
            return False
        if b % 128 == 0:
            progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
        if doc_store_result:
@@ -737,7 +743,7 @@ async def insert_es(task_id, task_tenant_id, task_dataset_id, chunks, progress_c
                for chunk_id in chunk_ids:
                    nursery.start_soon(delete_image, task_dataset_id, chunk_id)
                progress_callback(-1, msg=f"Chunk updates failed since task {task_id} is unknown.")
                return
                return False
    return True


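The embedding hunk above tiles the title vector and only blends it with the content vectors when both matrices really have the same 2-D shape. A small illustrative sketch of that blend; `blend_embeddings` is a made-up helper, the weight and shape guard mirror the change.

```python
import numpy as np


def blend_embeddings(title_vec: np.ndarray, content_vecs: np.ndarray, title_w: float = 0.1) -> np.ndarray:
    """Tile the title embedding to one row per content chunk, then take a weighted
    average; fall back to the content embeddings if the shapes cannot be aligned."""
    tts = np.tile(title_vec, (len(content_vecs), 1))
    if tts.ndim == 2 and content_vecs.ndim == 2 and tts.shape == content_vecs.shape:
        return title_w * tts + (1 - title_w) * content_vecs
    return content_vecs


if __name__ == "__main__":
    title = np.ones(4)
    contents = np.arange(12, dtype=float).reshape(3, 4)
    print(blend_embeddings(title, contents, title_w=0.1).shape)  # (3, 4)
```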
@@ -67,6 +67,8 @@ class RAGFlowAzureSpnBlob:
            logging.exception(f"Fail put {bucket}/{fnm}")
            self.__open__()
            time.sleep(1)
        return None
        return None

    def rm(self, bucket, fnm):
        try:
@@ -84,7 +86,7 @@ class RAGFlowAzureSpnBlob:
            logging.exception(f"fail get {bucket}/{fnm}")
            self.__open__()
            time.sleep(1)
        return
        return None

    def obj_exist(self, bucket, fnm):
        try:
@@ -102,4 +104,4 @@ class RAGFlowAzureSpnBlob:
            logging.exception(f"fail get {bucket}/{fnm}")
            self.__open__()
            time.sleep(1)
        return
        return None
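The Azure and MinIO hunks in this diff all follow the same recovery shape: log the exception, re-open the client, back off briefly, and return `None` so the caller can decide what to do. A self-contained sketch of that pattern; `FlakyStore` and its simulated failure are illustrative, not RAGFlow code.

```python
import logging
import time


class FlakyStore:
    """Illustrative storage adapter: on failure, log, reconnect, back off, return None."""

    def __open__(self):
        self.conn = object()  # stand-in for a real client connection

    def __init__(self):
        self.__open__()

    def get(self, bucket: str, name: str):
        try:
            raise IOError("simulated backend failure")  # force the error path for the demo
        except Exception:
            logging.exception(f"fail get {bucket}/{name}")
            self.__open__()
            time.sleep(1)
        return None


if __name__ == "__main__":
    print(FlakyStore().get("bucket", "file.bin"))  # -> None
```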
@@ -241,23 +241,23 @@ class DocStoreConnection(ABC):
    """

    @abstractmethod
    def getTotal(self, res):
    def get_total(self, res):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getChunkIds(self, res):
    def get_chunk_ids(self, res):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getHighlight(self, res, keywords: list[str], fieldnm: str):
    def get_highlight(self, res, keywords: list[str], fieldnm: str):
        raise NotImplementedError("Not implemented")

    @abstractmethod
    def getAggregation(self, res, fieldnm: str):
    def get_aggregation(self, res, fieldnm: str):
        raise NotImplementedError("Not implemented")

    """

@@ -471,12 +471,12 @@ class ESConnection(DocStoreConnection):
    Helper functions for search result
    """

    def getTotal(self, res):
    def get_total(self, res):
        if isinstance(res["hits"]["total"], type({})):
            return res["hits"]["total"]["value"]
        return res["hits"]["total"]

    def getChunkIds(self, res):
    def get_chunk_ids(self, res):
        return [d["_id"] for d in res["hits"]["hits"]]

    def __getSource(self, res):
@@ -487,7 +487,7 @@ class ESConnection(DocStoreConnection):
            rr.append(d["_source"])
        return rr

    def getFields(self, res, fields: list[str]) -> dict[str, dict]:
    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        res_fields = {}
        if not fields:
            return {}
@@ -509,7 +509,7 @@ class ESConnection(DocStoreConnection):
            res_fields[d["id"]] = m
        return res_fields

    def getHighlight(self, res, keywords: list[str], fieldnm: str):
    def get_highlight(self, res, keywords: list[str], fieldnm: str):
        ans = {}
        for d in res["hits"]["hits"]:
            hlts = d.get("highlight")
@@ -534,7 +534,7 @@ class ESConnection(DocStoreConnection):

        return ans

    def getAggregation(self, res, fieldnm: str):
    def get_aggregation(self, res, fieldnm: str):
        agg_field = "aggs_" + fieldnm
        if "aggregations" not in res or agg_field not in res["aggregations"]:
            return list()

@@ -470,7 +470,7 @@ class InfinityConnection(DocStoreConnection):
        df_list.append(kb_res)
        self.connPool.release_conn(inf_conn)
        res = concat_dataframes(df_list, ["id"])
        res_fields = self.getFields(res, res.columns.tolist())
        res_fields = self.get_fields(res, res.columns.tolist())
        return res_fields.get(chunkId, None)

    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
@@ -599,7 +599,7 @@ class InfinityConnection(DocStoreConnection):
        col_to_remove = list(removeValue.keys())
        row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
        logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
        row_to_opt = self.getFields(row_to_opt, col_to_remove)
        row_to_opt = self.get_fields(row_to_opt, col_to_remove)
        for id, old_v in row_to_opt.items():
            for k, remove_v in removeValue.items():
                if remove_v in old_v[k]:
@@ -639,17 +639,17 @@ class InfinityConnection(DocStoreConnection):
    Helper functions for search result
    """

    def getTotal(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
    def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
        if isinstance(res, tuple):
            return res[1]
        return len(res)

    def getChunkIds(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
    def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
        if isinstance(res, tuple):
            res = res[0]
        return list(res["id"])

    def getFields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
    def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
        if isinstance(res, tuple):
            res = res[0]
        if not fields:
@@ -690,7 +690,7 @@ class InfinityConnection(DocStoreConnection):

        return res2.set_index("id").to_dict(orient="index")

    def getHighlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
    def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
        if isinstance(res, tuple):
            res = res[0]
        ans = {}
@@ -732,7 +732,7 @@ class InfinityConnection(DocStoreConnection):
            ans[id] = txt
        return ans

    def getAggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
    def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
        """
        Manual aggregation for tag fields since Infinity doesn't provide native aggregation
        """

@@ -92,7 +92,7 @@ class RAGFlowMinio:
            logging.exception(f"Fail to get {bucket}/{filename}")
            self.__open__()
            time.sleep(1)
        return
        return None

    def obj_exist(self, bucket, filename, tenant_id=None):
        try:
@@ -130,7 +130,7 @@ class RAGFlowMinio:
            logging.exception(f"Fail to get_presigned {bucket}/{fnm}:")
            self.__open__()
            time.sleep(1)
        return
        return None

    def remove_bucket(self, bucket):
        try:
