mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-02 16:45:08 +08:00
Compare commits
3 Commits
57edc215d7
...
6cd1824a77
| Author | SHA1 | Date | |
|---|---|---|---|
| 6cd1824a77 | |||
| 2844700dc4 | |||
| f8fd1ea7e1 |
@ -157,7 +157,7 @@ async def add_llm():
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
||||
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
|
||||
api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"])
|
||||
|
||||
elif factory == "LocalAI":
|
||||
llm_name += "___LocalAI"
|
||||
|
||||
@ -33,7 +33,7 @@ from api.db.services.dialog_service import DialogService, async_ask, async_chat,
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
@ -129,11 +129,33 @@ async def chat_completion(tenant_id, chat_id):
|
||||
req = {"question": ""}
|
||||
if not req.get("session_id"):
|
||||
req["question"] = ""
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||
dia = dia[0]
|
||||
if req.get("session_id"):
|
||||
if not ConversationService.query(id=req["session_id"], dialog_id=chat_id):
|
||||
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
||||
|
||||
metadata_condition = req.get("metadata_condition") or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_error_data_result(message="metadata_condition must be an object.")
|
||||
|
||||
if metadata_condition and req.get("question"):
|
||||
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
|
||||
filtered_doc_ids = meta_filter(
|
||||
metas,
|
||||
convert_conditions(metadata_condition),
|
||||
metadata_condition.get("logic", "and"),
|
||||
)
|
||||
if metadata_condition.get("conditions") and not filtered_doc_ids:
|
||||
filtered_doc_ids = ["-999"]
|
||||
|
||||
if filtered_doc_ids:
|
||||
req["doc_ids"] = ",".join(filtered_doc_ids)
|
||||
else:
|
||||
req.pop("doc_ids", None)
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
@ -196,7 +218,19 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
{"role": "user", "content": "Can you tell me how to install neovim"},
|
||||
],
|
||||
stream=stream,
|
||||
extra_body={"reference": reference}
|
||||
extra_body={
|
||||
"reference": reference,
|
||||
"metadata_condition": {
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "author",
|
||||
"comparison_operator": "is",
|
||||
"value": "bob"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if stream:
|
||||
@ -212,7 +246,11 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
"""
|
||||
req = await get_request_json()
|
||||
|
||||
need_reference = bool(req.get("reference", False))
|
||||
extra_body = req.get("extra_body") or {}
|
||||
if extra_body and not isinstance(extra_body, dict):
|
||||
return get_error_data_result("extra_body must be an object.")
|
||||
|
||||
need_reference = bool(extra_body.get("reference", False))
|
||||
|
||||
messages = req.get("messages", [])
|
||||
# To prevent empty [] input
|
||||
@ -230,6 +268,22 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||
dia = dia[0]
|
||||
|
||||
metadata_condition = extra_body.get("metadata_condition") or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_error_data_result(message="metadata_condition must be an object.")
|
||||
|
||||
doc_ids_str = None
|
||||
if metadata_condition:
|
||||
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
|
||||
filtered_doc_ids = meta_filter(
|
||||
metas,
|
||||
convert_conditions(metadata_condition),
|
||||
metadata_condition.get("logic", "and"),
|
||||
)
|
||||
if metadata_condition.get("conditions") and not filtered_doc_ids:
|
||||
filtered_doc_ids = ["-999"]
|
||||
doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None
|
||||
|
||||
# Filter system and non-sense assistant messages
|
||||
msg = []
|
||||
for m in messages:
|
||||
@ -277,14 +331,17 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
}
|
||||
|
||||
try:
|
||||
async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||
chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
|
||||
if doc_ids_str:
|
||||
chat_kwargs["doc_ids"] = doc_ids_str
|
||||
async for ans in async_chat(dia, msg, True, **chat_kwargs):
|
||||
last_ans = ans
|
||||
answer = ans["answer"]
|
||||
|
||||
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
|
||||
if reasoning_match:
|
||||
reasoning_part = reasoning_match.group(1)
|
||||
content_part = answer[reasoning_match.end():]
|
||||
content_part = answer[reasoning_match.end() :]
|
||||
else:
|
||||
reasoning_part = ""
|
||||
content_part = answer
|
||||
@ -329,8 +386,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
response["choices"][0]["delta"]["content"] = None
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
response["choices"][0]["finish_reason"] = "stop"
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used,
|
||||
"total_tokens": len(prompt) + token_used}
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
|
||||
if need_reference:
|
||||
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
|
||||
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
|
||||
@ -345,7 +401,10 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
return resp
|
||||
else:
|
||||
answer = None
|
||||
async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||
chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
|
||||
if doc_ids_str:
|
||||
chat_kwargs["doc_ids"] = doc_ids_str
|
||||
async for ans in async_chat(dia, msg, False, **chat_kwargs):
|
||||
# focus answer content only
|
||||
answer = ans
|
||||
break
|
||||
|
||||
@ -48,6 +48,7 @@ This API follows the same request and response format as OpenAI's API. It allows
|
||||
- `"model"`: `string`
|
||||
- `"messages"`: `object list`
|
||||
- `"stream"`: `boolean`
|
||||
- `"extra_body"`: `object` (optional)
|
||||
|
||||
##### Request example
|
||||
|
||||
@ -59,7 +60,20 @@ curl --request POST \
|
||||
--data '{
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "Say this is a test!"}],
|
||||
"stream": true
|
||||
"stream": true,
|
||||
"extra_body": {
|
||||
"reference": true,
|
||||
"metadata_condition": {
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "author",
|
||||
"comparison_operator": "is",
|
||||
"value": "bob"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
@ -74,6 +88,11 @@ curl --request POST \
|
||||
- `stream` (*Body parameter*) `boolean`
|
||||
Whether to receive the response as a stream. Set this to `false` explicitly if you prefer to receive the entire response in one go instead of as a stream.
|
||||
|
||||
- `extra_body` (*Body parameter*) `object`
|
||||
Extra request parameters:
|
||||
- `reference`: `boolean` - include reference in the final chunk (stream) or in the final message (non-stream).
|
||||
- `metadata_condition`: `object` - metadata filter conditions applied to retrieval results.
|
||||
|
||||
#### Response
|
||||
|
||||
Stream:
|
||||
@ -3185,6 +3204,7 @@ Asks a specified chat assistant a question to start an AI-powered conversation.
|
||||
- `"stream"`: `boolean`
|
||||
- `"session_id"`: `string` (optional)
|
||||
- `"user_id`: `string` (optional)
|
||||
- `"metadata_condition"`: `object` (optional)
|
||||
|
||||
##### Request example
|
||||
|
||||
@ -3207,7 +3227,17 @@ curl --request POST \
|
||||
{
|
||||
"question": "Who are you",
|
||||
"stream": true,
|
||||
"session_id":"9fa7691cb85c11ef9c5f0242ac120005"
|
||||
"session_id":"9fa7691cb85c11ef9c5f0242ac120005",
|
||||
"metadata_condition": {
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "author",
|
||||
"comparison_operator": "is",
|
||||
"value": "bob"
|
||||
}
|
||||
]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
@ -3225,6 +3255,13 @@ curl --request POST \
|
||||
The ID of session. If it is not provided, a new session will be generated.
|
||||
- `"user_id"`: (*Body parameter*), `string`
|
||||
The optional user-defined ID. Valid *only* when no `session_id` is provided.
|
||||
- `"metadata_condition"`: (*Body parameter*), `object`
|
||||
Optional metadata filter conditions applied to retrieval results.
|
||||
- `logic`: `string`, one of `and` / `or`
|
||||
- `conditions`: `list[object]` where each condition contains:
|
||||
- `name`: `string` metadata key
|
||||
- `comparison_operator`: `string` (e.g. `is`, `not is`, `contains`, `not contains`, `start with`, `end with`, `empty`, `not empty`, `>`, `<`, `≥`, `≤`)
|
||||
- `value`: `string|number|boolean` (optional for `empty`/`not empty`)
|
||||
|
||||
#### Response
|
||||
|
||||
|
||||
@ -1217,11 +1217,7 @@ class LiteLLMBase(ABC):
|
||||
self.toolcall_sessions = {}
|
||||
|
||||
# Factory specific fields
|
||||
if self.provider == SupportedLiteLLMProvider.Bedrock:
|
||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||
if self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||
self.api_key = json.loads(key).get("api_key", "")
|
||||
self.provider_order = json.loads(key).get("provider_order", "")
|
||||
elif self.provider == SupportedLiteLLMProvider.Azure_OpenAI:
|
||||
@ -1624,17 +1620,38 @@ class LiteLLMBase(ABC):
|
||||
if self.provider in FACTORY_DEFAULT_BASE_URL:
|
||||
completion_args.update({"api_base": self.base_url})
|
||||
elif self.provider == SupportedLiteLLMProvider.Bedrock:
|
||||
import boto3
|
||||
|
||||
completion_args.pop("api_key", None)
|
||||
completion_args.pop("api_base", None)
|
||||
bedrock_credentials = { "aws_region_name": self.bedrock_region }
|
||||
if self.bedrock_ak and self.bedrock_sk:
|
||||
bedrock_credentials["aws_access_key_id"] = self.bedrock_ak
|
||||
bedrock_credentials["aws_secret_access_key"] = self.bedrock_sk
|
||||
|
||||
bedrock_key = json.loads(self.api_key)
|
||||
mode = bedrock_key.get("auth_mode")
|
||||
if not mode:
|
||||
logging.error("Bedrock auth_mode is not provided in the key")
|
||||
raise ValueError("Bedrock auth_mode must be provided in the key")
|
||||
|
||||
bedrock_region = bedrock_key.get("bedrock_region")
|
||||
bedrock_credentials = {"bedrock_region": bedrock_region}
|
||||
|
||||
if mode == "access_key_secret":
|
||||
bedrock_credentials["aws_access_key_id"] = bedrock_key.get("bedrock_ak")
|
||||
bedrock_credentials["aws_secret_access_key"] = bedrock_key.get("bedrock_sk")
|
||||
elif mode == "iam_role":
|
||||
aws_role_arn = bedrock_key.get("aws_role_arn")
|
||||
sts_client = boto3.client("sts", region_name=bedrock_region)
|
||||
resp = sts_client.assume_role(RoleArn=aws_role_arn, RoleSessionName="BedrockSession")
|
||||
creds = resp["Credentials"]
|
||||
bedrock_credentials["aws_access_key_id"] = creds["AccessKeyId"]
|
||||
bedrock_credentials["aws_secret_access_key"] = creds["SecretAccessKey"]
|
||||
bedrock_credentials["aws_session_token"] = creds["SessionToken"]
|
||||
|
||||
completion_args.update(
|
||||
{
|
||||
"bedrock_credentials": bedrock_credentials,
|
||||
}
|
||||
)
|
||||
|
||||
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
|
||||
if self.provider_order:
|
||||
|
||||
|
||||
@ -463,20 +463,44 @@ class BedrockEmbed(Base):
|
||||
|
||||
def __init__(self, key, model_name, **kwargs):
|
||||
import boto3
|
||||
# `key` protocol (backend stores as JSON string in `api_key`):
|
||||
# - Must decode into a dict.
|
||||
# - Required: `auth_mode`, `bedrock_region`.
|
||||
# - Supported auth modes:
|
||||
# - "access_key_secret": requires `bedrock_ak` + `bedrock_sk`.
|
||||
# - "iam_role": requires `aws_role_arn` and assumes role via STS.
|
||||
# - else: treated as "assume_role" (default AWS credential chain).
|
||||
key = json.loads(key)
|
||||
mode = key.get("auth_mode")
|
||||
if not mode:
|
||||
logging.error("Bedrock auth_mode is not provided in the key")
|
||||
raise ValueError("Bedrock auth_mode must be provided in the key")
|
||||
|
||||
self.bedrock_region = key.get("bedrock_region")
|
||||
|
||||
self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
|
||||
self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
|
||||
self.bedrock_region = json.loads(key).get("bedrock_region", "")
|
||||
self.model_name = model_name
|
||||
self.is_amazon = self.model_name.split(".")[0] == "amazon"
|
||||
self.is_cohere = self.model_name.split(".")[0] == "cohere"
|
||||
|
||||
if self.bedrock_ak == "" or self.bedrock_sk == "":
|
||||
# Try to create a client using the default credentials if ak/sk are not provided.
|
||||
# Must provide a region.
|
||||
self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)
|
||||
else:
|
||||
|
||||
if mode == "access_key_secret":
|
||||
self.bedrock_ak = key.get("bedrock_ak")
|
||||
self.bedrock_sk = key.get("bedrock_sk")
|
||||
self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
|
||||
elif mode == "iam_role":
|
||||
self.aws_role_arn = key.get("aws_role_arn")
|
||||
sts_client = boto3.client("sts", region_name=self.bedrock_region)
|
||||
resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockSession")
|
||||
creds = resp["Credentials"]
|
||||
|
||||
self.client = boto3.client(
|
||||
service_name="bedrock-runtime",
|
||||
aws_access_key_id=creds["AccessKeyId"],
|
||||
aws_secret_access_key=creds["SecretAccessKey"],
|
||||
aws_session_token=creds["SessionToken"],
|
||||
)
|
||||
else: # assume_role
|
||||
self.client = boto3.client("bedrock-runtime", region_name=self.bedrock_region)
|
||||
|
||||
|
||||
def encode(self, texts: list):
|
||||
texts = [truncate(t, 8196) for t in texts]
|
||||
|
||||
@ -787,6 +787,15 @@ This auto-tagging feature enhances retrieval by adding another layer of domain-s
|
||||
deleteModel: 'Delete model',
|
||||
bedrockCredentialsHint:
|
||||
'Tip: Leave Access Key / Secret Key blank to use AWS IAM authentication.',
|
||||
awsAuthModeAccessKeySecret: 'Access Key',
|
||||
awsAuthModeIamRole: 'IAM Role',
|
||||
awsAuthModeAssumeRole: 'Assume Role',
|
||||
awsAccessKeyId: 'AWS Access Key ID',
|
||||
awsSecretAccessKey: 'AWS Secret Access Key',
|
||||
awsRoleArn: 'AWS Role ARN',
|
||||
awsRoleArnMessage: 'Please enter AWS Role ARN',
|
||||
awsAssumeRoleTip:
|
||||
'If you select this mode, the Amazon EC2 instance will assume its existing role to access AWS services. No additional credentials are required.',
|
||||
modelEmptyTip:
|
||||
'No models available. <br>Please add models from the panel on the right.',
|
||||
sourceEmptyTip: 'No data sources added yet. Select one below to connect.',
|
||||
@ -1101,6 +1110,9 @@ Example: Virtual Hosted Style`,
|
||||
mcp: 'MCP',
|
||||
mineru: {
|
||||
modelNameRequired: 'Model name is required',
|
||||
apiServerRequired: 'MinerU API Server Configuration is required',
|
||||
serverUrlBackendLimit:
|
||||
'MinerU Server URL Address is only available for the HTTP client backend',
|
||||
apiserver: 'MinerU API Server Configuration',
|
||||
outputDir: 'MinerU Output Directory Path',
|
||||
backend: 'MinerU Processing Backend Type',
|
||||
@ -1112,6 +1124,11 @@ Example: Virtual Hosted Style`,
|
||||
vlmTransformers: 'Vision Language Model with Transformers',
|
||||
vlmVllmEngine: 'Vision Language Model with vLLM Engine',
|
||||
vlmHttpClient: 'Vision Language Model via HTTP Client',
|
||||
vlmMlxEngine: 'Vision Language Model with MLX Engine',
|
||||
vlmVllmAsyncEngine:
|
||||
'Vision Language Model with vLLM Async Engine (Experimental)',
|
||||
vlmLmdeployEngine:
|
||||
'Vision Language Model with LMDeploy Engine (Experimental)',
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@ -546,6 +546,15 @@ export default {
|
||||
profileDescription: '在此更新您的照片和個人詳細信息。',
|
||||
bedrockCredentialsHint:
|
||||
'提示:Access Key / Secret Key 可留空,以啟用 AWS IAM 自動驗證。',
|
||||
awsAuthModeAccessKeySecret: 'Access Key 和 Secret',
|
||||
awsAuthModeIamRole: 'IAM Role',
|
||||
awsAuthModeAssumeRole: 'Assume Role',
|
||||
awsAccessKeyId: 'AWS Access Key ID',
|
||||
awsSecretAccessKey: 'AWS Secret Access Key',
|
||||
awsRoleArn: 'AWS Role ARN',
|
||||
awsRoleArnMessage: '請輸入 AWS Role ARN',
|
||||
awsAssumeRoleTip:
|
||||
'選擇此模式後,EC2 執行個體將使用其既有的 IAM Role 存取 AWS 服務,無需額外憑證。',
|
||||
maxTokens: '最大token數',
|
||||
maxTokensMessage: '最大token數是必填項',
|
||||
maxTokensTip:
|
||||
|
||||
@ -876,6 +876,15 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
bedrockAKMessage: '请输入 ACCESS KEY',
|
||||
addBedrockSK: 'SECRET KEY',
|
||||
bedrockSKMessage: '请输入 SECRET KEY',
|
||||
awsAuthModeAccessKeySecret: 'Access Key 和 Secret',
|
||||
awsAuthModeIamRole: 'IAM Role',
|
||||
awsAuthModeAssumeRole: 'Assume Role',
|
||||
awsAccessKeyId: 'AWS Access Key ID',
|
||||
awsSecretAccessKey: 'AWS Secret Access Key',
|
||||
awsRoleArn: 'AWS Role ARN',
|
||||
awsRoleArnMessage: '请输入 AWS Role ARN',
|
||||
awsAssumeRoleTip:
|
||||
'选择此模式后,EC2 实例将使用其已有的 IAM Role 访问 AWS 服务,无需额外的凭证。',
|
||||
bedrockRegion: 'AWS Region',
|
||||
bedrockRegionMessage: '请选择!',
|
||||
'us-east-1': '美国东部 (弗吉尼亚北部)',
|
||||
@ -950,6 +959,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
mcp: 'MCP',
|
||||
mineru: {
|
||||
modelNameRequired: '模型名称为必填项',
|
||||
apiServerRequired: 'MinerU API服务器配置为必填项',
|
||||
serverUrlBackendLimit: '仅在backend 为vlm-http-client 时可填写',
|
||||
apiserver: 'MinerU API服务器配置',
|
||||
outputDir: 'MinerU输出目录路径',
|
||||
backend: 'MinerU处理后端类型',
|
||||
@ -961,6 +972,9 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
vlmTransformers: '基于Transformers的视觉语言模型',
|
||||
vlmVllmEngine: '基于vLLM引擎的视觉语言模型',
|
||||
vlmHttpClient: '通过HTTP客户端连接的视觉语言模型',
|
||||
vlmMlxEngine: '基于MLX引擎的视觉语言模型',
|
||||
vlmVllmAsyncEngine: '基于vLLM异步引擎的视觉语言模型(实验性)',
|
||||
vlmLmdeployEngine: '基于LMDeploy引擎的视觉语言模型(实验性)',
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@ -472,10 +472,13 @@ export const useSubmitMinerU = () => {
|
||||
|
||||
const onMineruOk = useCallback(
|
||||
async (payload: MinerUFormValues) => {
|
||||
const cfg = {
|
||||
const cfg: any = {
|
||||
...payload,
|
||||
mineru_delete_output: payload.mineru_delete_output ?? true ? '1' : '0',
|
||||
};
|
||||
if (payload.mineru_backend !== 'vlm-http-client') {
|
||||
delete cfg.mineru_server_url;
|
||||
}
|
||||
const req: IAddLlmRequestBody = {
|
||||
llm_factory: LLMFactory.MinerU,
|
||||
llm_name: payload.llm_name,
|
||||
|
||||
@ -1,15 +1,25 @@
|
||||
import { useTranslate } from '@/hooks/common-hooks';
|
||||
import { IModalProps } from '@/interfaces/common';
|
||||
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
|
||||
import { Form, Input, InputNumber, Modal, Select, Typography } from 'antd';
|
||||
import { useMemo } from 'react';
|
||||
import {
|
||||
Form,
|
||||
Input,
|
||||
InputNumber,
|
||||
Modal,
|
||||
Segmented,
|
||||
Select,
|
||||
Typography,
|
||||
} from 'antd';
|
||||
import { useMemo, useState } from 'react';
|
||||
import { LLMHeader } from '../../components/llm-header';
|
||||
import { BedrockRegionList } from '../../constant';
|
||||
|
||||
type FieldType = IAddLlmRequestBody & {
|
||||
auth_mode?: 'access_key_secret' | 'iam_role' | 'assume_role';
|
||||
bedrock_ak: string;
|
||||
bedrock_sk: string;
|
||||
bedrock_region: string;
|
||||
aws_role_arn?: string;
|
||||
};
|
||||
|
||||
const { Option } = Select;
|
||||
@ -23,6 +33,8 @@ const BedrockModal = ({
|
||||
llmFactory,
|
||||
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
|
||||
const [form] = Form.useForm<FieldType>();
|
||||
const [authMode, setAuthMode] =
|
||||
useState<FieldType['auth_mode']>('access_key_secret');
|
||||
|
||||
const { t } = useTranslate('setting');
|
||||
const options = useMemo(
|
||||
@ -33,13 +45,32 @@ const BedrockModal = ({
|
||||
const handleOk = async () => {
|
||||
const values = await form.validateFields();
|
||||
|
||||
// Only submit fields related to the active auth mode.
|
||||
const cleanedValues: Record<string, any> = { ...values };
|
||||
|
||||
const fieldsByMode: Record<string, string[]> = {
|
||||
access_key_secret: ['bedrock_ak', 'bedrock_sk'],
|
||||
iam_role: ['aws_role_arn'],
|
||||
assume_role: [],
|
||||
};
|
||||
|
||||
cleanedValues.auth_mode = authMode;
|
||||
|
||||
Object.keys(fieldsByMode).forEach((mode) => {
|
||||
if (mode !== authMode) {
|
||||
fieldsByMode[mode].forEach((field) => {
|
||||
delete cleanedValues[field];
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
const data = {
|
||||
...values,
|
||||
...cleanedValues,
|
||||
llm_factory: llmFactory,
|
||||
max_tokens: values.max_tokens,
|
||||
};
|
||||
|
||||
onOk?.(data);
|
||||
onOk?.(data as unknown as IAddLlmRequestBody);
|
||||
};
|
||||
|
||||
return (
|
||||
@ -47,9 +78,6 @@ const BedrockModal = ({
|
||||
title={
|
||||
<div>
|
||||
<LLMHeader name={llmFactory} />
|
||||
<Text type="secondary" style={{ display: 'block', marginTop: 4 }}>
|
||||
{t('bedrockCredentialsHint')}
|
||||
</Text>
|
||||
</div>
|
||||
}
|
||||
open={visible}
|
||||
@ -82,20 +110,75 @@ const BedrockModal = ({
|
||||
>
|
||||
<Input placeholder={t('bedrockModelNameMessage')} />
|
||||
</Form.Item>
|
||||
<Form.Item<FieldType>
|
||||
label={t('addBedrockEngineAK')}
|
||||
name="bedrock_ak"
|
||||
rules={[{ message: t('bedrockAKMessage') }]}
|
||||
>
|
||||
<Input placeholder={t('bedrockAKMessage')} />
|
||||
</Form.Item>
|
||||
<Form.Item<FieldType>
|
||||
label={t('addBedrockSK')}
|
||||
name="bedrock_sk"
|
||||
rules={[{ message: t('bedrockSKMessage') }]}
|
||||
>
|
||||
<Input placeholder={t('bedrockSKMessage')} />
|
||||
|
||||
{/* AWS Credential Mode Switch (AK/SK section only) */}
|
||||
<Form.Item>
|
||||
<Segmented
|
||||
block
|
||||
value={authMode}
|
||||
onChange={(v) => {
|
||||
const next = v as FieldType['auth_mode'];
|
||||
setAuthMode(next);
|
||||
// Clear non-active fields so they won't be validated/submitted by accident.
|
||||
if (next !== 'access_key_secret') {
|
||||
form.setFieldsValue({ bedrock_ak: '', bedrock_sk: '' } as any);
|
||||
}
|
||||
if (next !== 'iam_role') {
|
||||
form.setFieldsValue({ aws_role_arn: '' } as any);
|
||||
}
|
||||
if (next !== 'assume_role') {
|
||||
form.setFieldsValue({ role_arn: '' } as any);
|
||||
}
|
||||
}}
|
||||
options={[
|
||||
{
|
||||
label: t('awsAuthModeAccessKeySecret'),
|
||||
value: 'access_key_secret',
|
||||
},
|
||||
{ label: t('awsAuthModeIamRole'), value: 'iam_role' },
|
||||
{ label: t('awsAuthModeAssumeRole'), value: 'assume_role' },
|
||||
]}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
{authMode === 'access_key_secret' && (
|
||||
<>
|
||||
<Form.Item<FieldType>
|
||||
label={t('awsAccessKeyId')}
|
||||
name="bedrock_ak"
|
||||
rules={[{ required: true, message: t('bedrockAKMessage') }]}
|
||||
>
|
||||
<Input placeholder={t('bedrockAKMessage')} />
|
||||
</Form.Item>
|
||||
<Form.Item<FieldType>
|
||||
label={t('awsSecretAccessKey')}
|
||||
name="bedrock_sk"
|
||||
rules={[{ required: true, message: t('bedrockSKMessage') }]}
|
||||
>
|
||||
<Input placeholder={t('bedrockSKMessage')} />
|
||||
</Form.Item>
|
||||
</>
|
||||
)}
|
||||
|
||||
{authMode === 'iam_role' && (
|
||||
<Form.Item<FieldType>
|
||||
label={t('awsRoleArn')}
|
||||
name="aws_role_arn"
|
||||
rules={[{ required: true, message: t('awsRoleArnMessage') }]}
|
||||
>
|
||||
<Input placeholder={t('awsRoleArnMessage')} />
|
||||
</Form.Item>
|
||||
)}
|
||||
|
||||
{authMode === 'assume_role' && (
|
||||
<Form.Item
|
||||
style={{ marginTop: -8, marginBottom: 16 }}
|
||||
// keep layout consistent with other modes
|
||||
>
|
||||
<Text type="secondary">{t('awsAssumeRoleTip')}</Text>
|
||||
</Form.Item>
|
||||
)}
|
||||
|
||||
<Form.Item<FieldType>
|
||||
label={t('bedrockRegion')}
|
||||
name="bedrock_region"
|
||||
|
||||
@ -16,7 +16,7 @@ import { IModalProps } from '@/interfaces/common';
|
||||
import { buildOptions } from '@/utils/form';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { t } from 'i18next';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { useForm, useWatch } from 'react-hook-form';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { z } from 'zod';
|
||||
import { LLMHeader } from '../../components/llm-header';
|
||||
@ -25,15 +25,18 @@ const FormSchema = z.object({
|
||||
llm_name: z.string().min(1, {
|
||||
message: t('setting.mineru.modelNameRequired'),
|
||||
}),
|
||||
mineru_apiserver: z.string().optional(),
|
||||
mineru_apiserver: z.string().url(),
|
||||
mineru_output_dir: z.string().optional(),
|
||||
mineru_backend: z.enum([
|
||||
'pipeline',
|
||||
'vlm-transformers',
|
||||
'vlm-vllm-engine',
|
||||
'vlm-http-client',
|
||||
'vlm-mlx-engine',
|
||||
'vlm-vllm-async-engine',
|
||||
'vlm-lmdeploy-engine',
|
||||
]),
|
||||
mineru_server_url: z.string().optional(),
|
||||
mineru_server_url: z.string().url().optional(),
|
||||
mineru_delete_output: z.boolean(),
|
||||
});
|
||||
|
||||
@ -52,6 +55,9 @@ const MinerUModal = ({
|
||||
'vlm-transformers',
|
||||
'vlm-vllm-engine',
|
||||
'vlm-http-client',
|
||||
'vlm-mlx-engine',
|
||||
'vlm-vllm-async-engine',
|
||||
'vlm-lmdeploy-engine',
|
||||
]);
|
||||
|
||||
const form = useForm<MinerUFormValues>({
|
||||
@ -62,6 +68,11 @@ const MinerUModal = ({
|
||||
},
|
||||
});
|
||||
|
||||
const backend = useWatch({
|
||||
control: form.control,
|
||||
name: 'mineru_backend',
|
||||
});
|
||||
|
||||
const handleOk = async (values: MinerUFormValues) => {
|
||||
const ret = await onOk?.(values as any);
|
||||
if (ret) {
|
||||
@ -93,6 +104,7 @@ const MinerUModal = ({
|
||||
<RAGFlowFormItem
|
||||
name="mineru_apiserver"
|
||||
label={t('setting.mineru.apiserver')}
|
||||
required
|
||||
>
|
||||
<Input placeholder="http://host.docker.internal:9987" />
|
||||
</RAGFlowFormItem>
|
||||
@ -109,18 +121,25 @@ const MinerUModal = ({
|
||||
{(field) => (
|
||||
<RAGFlowSelect
|
||||
value={field.value}
|
||||
onChange={field.onChange}
|
||||
onChange={(value) => {
|
||||
field.onChange(value);
|
||||
if (value !== 'vlm-http-client') {
|
||||
form.setValue('mineru_server_url', undefined);
|
||||
}
|
||||
}}
|
||||
options={backendOptions}
|
||||
placeholder={t('setting.mineru.selectBackend')}
|
||||
/>
|
||||
)}
|
||||
</RAGFlowFormItem>
|
||||
<RAGFlowFormItem
|
||||
name="mineru_server_url"
|
||||
label={t('setting.mineru.serverUrl')}
|
||||
>
|
||||
<Input placeholder="http://your-vllm-server:30000" />
|
||||
</RAGFlowFormItem>
|
||||
{backend === 'vlm-http-client' && (
|
||||
<RAGFlowFormItem
|
||||
name="mineru_server_url"
|
||||
label={t('setting.mineru.serverUrl')}
|
||||
>
|
||||
<Input placeholder="http://your-vllm-server:30000" />
|
||||
</RAGFlowFormItem>
|
||||
)}
|
||||
<RAGFlowFormItem
|
||||
name="mineru_delete_output"
|
||||
label={t('setting.mineru.deleteOutput')}
|
||||
|
||||
Reference in New Issue
Block a user