Compare commits

...

4 Commits

Author SHA1 Message Date
da5cef0686 Refactor:Improve the float compare for LocalAIRerank (#9428)
### What problem does this PR solve?
Improve the float compare for LocalAIRerank

### Type of change

- [x] Refactoring
2025-08-13 10:26:42 +08:00
9098efb8aa Feat: Fixed the issue where some fields in the chat configuration could not be displayed #3221 (#9430)
### What problem does this PR solve?

Feat: Fixed the issue where some fields in the chat configuration could
not be displayed #3221
### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-08-13 10:26:26 +08:00
421657f64b Feat: allows setting multiple types of default models in service config (#9404)
### What problem does this PR solve?

Allows set multiple types of default models in service config.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
2025-08-13 09:46:05 +08:00
7ee5e0d152 Fix KeyError in session listing endpoint when accessing conversation reference (#9419)
- Add type and boundary checks for conv["reference"] access
- Prevent KeyError: 0 when reference list is empty or malformed
- Ensure reference is list type before indexing
- Handle cases where reference items are None or missing chunks
- Maintains backward compatibility with existing data structures

This resolves crashes in /api/v1/agents/<agent_id>/sessions endpoint
when conversation reference data is not properly structured.

### What problem does this PR solve?

This PR fixes a critical `KeyError: 0` that occurs in the
`/api/v1/agents/<agent_id>/sessions` endpoint when the system attempts
to access conversation reference data that is not properly structured.

**Background Context:**
The `list_agent_session` method in `api/apps/sdk/session.py` assumes
that `conv["reference"]` is always a properly indexed list with valid
dictionary structures. However, in real-world scenarios, this data can
be:
- Not a list type (could be None, string, or other types)
- An empty list when `chunk_num` tries to access index 0
- Contains None values or malformed dictionary structures
- Missing expected "chunks" keys in reference items

**Impact Before Fix:**
When malformed reference data is encountered, the API crashes with:
```json
{
    "code": 100,
    "data": null,
    "message": "KeyError(0)"
}
```
**Solution:**
Added comprehensive safety checks including type validation, boundary
checking, null safety, and structure validation to ensure the API
gracefully handles all reference data formats while maintaining backward
compatibility.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [ ] New Feature (non-breaking change which adds functionality)
- [ ] Documentation Update
- [ ] Refactoring
- [ ] Performance Improvement
- [ ] Other (please describe):
2025-08-13 09:23:52 +08:00
22 changed files with 391 additions and 157 deletions

View File

@ -589,14 +589,22 @@ def list_agent_session(tenant_id, agent_id):
if "prompt" in info:
info.pop("prompt")
conv["agent_id"] = conv.pop("dialog_id")
# Fix for session listing endpoint
if conv["reference"]:
messages = conv["messages"]
message_num = 0
chunk_num = 0
# Ensure reference is a list type to prevent KeyError
if not isinstance(conv["reference"], list):
conv["reference"] = []
while message_num < len(messages):
if message_num != 0 and messages[message_num]["role"] != "user":
chunk_list = []
if "chunks" in conv["reference"][chunk_num]:
# Add boundary and type checks to prevent KeyError
if (chunk_num < len(conv["reference"]) and
conv["reference"][chunk_num] is not None and
isinstance(conv["reference"][chunk_num], dict) and
"chunks" in conv["reference"][chunk_num]):
chunks = conv["reference"][chunk_num]["chunks"]
for chunk in chunks:
new_chunk = {

View File

@ -620,18 +620,35 @@ def user_register(user_id, user):
"location": "",
}
tenant_llm = []
for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": settings.LLM_FACTORY,
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": settings.API_KEY,
"api_base": settings.LLM_BASE_URL,
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
seen = set()
factory_configs = []
for factory_config in [
settings.CHAT_CFG,
settings.EMBEDDING_CFG,
settings.ASR_CFG,
settings.IMAGE2TEXT_CFG,
settings.RERANK_CFG,
]:
factory_name = factory_config["factory"]
if factory_name not in seen:
seen.add(factory_name)
factory_configs.append(factory_config)
for factory_config in factory_configs:
for llm in LLMService.query(fid=factory_config["factory"]):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": factory_config["factory"],
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": factory_config["api_key"],
"api_base": factory_config["base_url"],
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
if settings.LIGHTEN != 1:
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
@ -647,6 +664,13 @@ def user_register(user_id, user):
}
)
unique = {}
for item in tenant_llm:
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
if key not in unique:
unique[key] = item
tenant_llm = list(unique.values())
if not UserService.save(**user):
return
TenantService.insert(**tenant)

View File

@ -63,12 +63,44 @@ def init_superuser():
"invited_by": user_info["id"],
"role": UserTenantRole.OWNER
}
user_id = user_info
tenant_llm = []
for llm in LLMService.query(fid=settings.LLM_FACTORY):
tenant_llm.append(
{"tenant_id": user_info["id"], "llm_factory": settings.LLM_FACTORY, "llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": settings.API_KEY, "api_base": settings.LLM_BASE_URL})
seen = set()
factory_configs = []
for factory_config in [
settings.CHAT_CFG["factory"],
settings.EMBEDDING_CFG["factory"],
settings.ASR_CFG["factory"],
settings.IMAGE2TEXT_CFG["factory"],
settings.RERANK_CFG["factory"],
]:
factory_name = factory_config["factory"]
if factory_name not in seen:
seen.add(factory_name)
factory_configs.append(factory_config)
for factory_config in factory_configs:
for llm in LLMService.query(fid=factory_config["factory"]):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": factory_config["factory"],
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": factory_config["api_key"],
"api_base": factory_config["base_url"],
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
unique = {}
for item in tenant_llm:
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
if key not in unique:
unique[key] = item
tenant_llm = list(unique.values())
if not UserService.save(**user_info):
logging.error("can't init admin.")
@ -103,7 +135,7 @@ def init_llm_factory():
except Exception:
pass
factory_llm_infos = settings.FACTORY_LLM_INFOS
factory_llm_infos = settings.FACTORY_LLM_INFOS
for factory_llm_info in factory_llm_infos:
info = deepcopy(factory_llm_info)
llm_infos = info.pop("llm")

View File

@ -121,6 +121,7 @@ class DialogService(CommonService):
cls.model.do_refer,
cls.model.rerank_id,
cls.model.kb_ids,
cls.model.icon,
cls.model.status,
User.nickname,
User.avatar.alias("tenant_avatar"),

View File

@ -38,6 +38,11 @@ EMBEDDING_MDL = ""
RERANK_MDL = ""
ASR_MDL = ""
IMAGE2TEXT_MDL = ""
CHAT_CFG = ""
EMBEDDING_CFG = ""
RERANK_CFG = ""
ASR_CFG = ""
IMAGE2TEXT_CFG = ""
API_KEY = None
PARSERS = None
HOST_IP = None
@ -74,23 +79,22 @@ STRONG_TEST_COUNT = int(os.environ.get("STRONG_TEST_COUNT", "8"))
BUILTIN_EMBEDDING_MODELS = ["BAAI/bge-large-zh-v1.5@BAAI", "maidalun1020/bce-embedding-base_v1@Youdao"]
def get_or_create_secret_key():
secret_key = os.environ.get("RAGFLOW_SECRET_KEY")
if secret_key and len(secret_key) >= 32:
return secret_key
# Check if there's a configured secret key
configured_key = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("secret_key")
if configured_key and configured_key != str(date.today()) and len(configured_key) >= 32:
return configured_key
# Generate a new secure key and warn about it
import logging
new_key = secrets.token_hex(32)
logging.warning(
"SECURITY WARNING: Using auto-generated SECRET_KEY. "
f"Generated key: {new_key}"
)
logging.warning(f"SECURITY WARNING: Using auto-generated SECRET_KEY. Generated key: {new_key}")
return new_key
@ -99,10 +103,10 @@ def init_settings():
LIGHTEN = int(os.environ.get("LIGHTEN", "0"))
DATABASE_TYPE = os.getenv("DB_TYPE", "mysql")
DATABASE = decrypt_database_config(name=DATABASE_TYPE)
LLM = get_base_config("user_default_llm", {})
LLM_DEFAULT_MODELS = LLM.get("default_models", {})
LLM_FACTORY = LLM.get("factory")
LLM_BASE_URL = LLM.get("base_url")
LLM = get_base_config("user_default_llm", {}) or {}
LLM_DEFAULT_MODELS = LLM.get("default_models", {}) or {}
LLM_FACTORY = LLM.get("factory", "") or ""
LLM_BASE_URL = LLM.get("base_url", "") or ""
try:
REGISTER_ENABLED = int(os.environ.get("REGISTER_ENABLED", "1"))
except Exception:
@ -115,29 +119,34 @@ def init_settings():
FACTORY_LLM_INFOS = []
global CHAT_MDL, EMBEDDING_MDL, RERANK_MDL, ASR_MDL, IMAGE2TEXT_MDL
global CHAT_CFG, EMBEDDING_CFG, RERANK_CFG, ASR_CFG, IMAGE2TEXT_CFG
if not LIGHTEN:
EMBEDDING_MDL = BUILTIN_EMBEDDING_MODELS[0]
if LLM_DEFAULT_MODELS:
CHAT_MDL = LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL)
EMBEDDING_MDL = LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL)
RERANK_MDL = LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL)
ASR_MDL = LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL)
IMAGE2TEXT_MDL = LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL)
# factory can be specified in the config name with "@". LLM_FACTORY will be used if not specified
CHAT_MDL = CHAT_MDL + (f"@{LLM_FACTORY}" if "@" not in CHAT_MDL and CHAT_MDL != "" else "")
EMBEDDING_MDL = EMBEDDING_MDL + (f"@{LLM_FACTORY}" if "@" not in EMBEDDING_MDL and EMBEDDING_MDL != "" else "")
RERANK_MDL = RERANK_MDL + (f"@{LLM_FACTORY}" if "@" not in RERANK_MDL and RERANK_MDL != "" else "")
ASR_MDL = ASR_MDL + (f"@{LLM_FACTORY}" if "@" not in ASR_MDL and ASR_MDL != "" else "")
IMAGE2TEXT_MDL = IMAGE2TEXT_MDL + (f"@{LLM_FACTORY}" if "@" not in IMAGE2TEXT_MDL and IMAGE2TEXT_MDL != "" else "")
global API_KEY, PARSERS, HOST_IP, HOST_PORT, SECRET_KEY
API_KEY = LLM.get("api_key")
PARSERS = LLM.get(
"parsers", "naive:General,qa:Q&A,resume:Resume,manual:Manual,table:Table,paper:Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag"
)
chat_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("chat_model", CHAT_MDL))
embedding_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("embedding_model", EMBEDDING_MDL))
rerank_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("rerank_model", RERANK_MDL))
asr_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("asr_model", ASR_MDL))
image2text_entry = _parse_model_entry(LLM_DEFAULT_MODELS.get("image2text_model", IMAGE2TEXT_MDL))
CHAT_CFG = _resolve_per_model_config(chat_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
EMBEDDING_CFG = _resolve_per_model_config(embedding_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
RERANK_CFG = _resolve_per_model_config(rerank_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
ASR_CFG = _resolve_per_model_config(asr_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
IMAGE2TEXT_CFG = _resolve_per_model_config(image2text_entry, LLM_FACTORY, API_KEY, LLM_BASE_URL)
CHAT_MDL = CHAT_CFG.get("model", "") or ""
EMBEDDING_MDL = EMBEDDING_CFG.get("model", "") or ""
RERANK_MDL = RERANK_CFG.get("model", "") or ""
ASR_MDL = ASR_CFG.get("model", "") or ""
IMAGE2TEXT_MDL = IMAGE2TEXT_CFG.get("model", "") or ""
HOST_IP = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("host", "127.0.0.1")
HOST_PORT = get_base_config(RAG_FLOW_SERVICE_NAME, {}).get("http_port")
@ -170,6 +179,7 @@ def init_settings():
retrievaler = search.Dealer(docStoreConn)
from graphrag import search as kg_search
kg_retrievaler = kg_search.KGSearch(docStoreConn)
if int(os.environ.get("SANDBOX_ENABLED", "0")):
@ -210,3 +220,34 @@ class RetCode(IntEnum, CustomEnum):
SERVER_ERROR = 500
FORBIDDEN = 403
NOT_FOUND = 404
def _parse_model_entry(entry):
if isinstance(entry, str):
return {"name": entry, "factory": None, "api_key": None, "base_url": None}
if isinstance(entry, dict):
name = entry.get("name") or entry.get("model") or ""
return {
"name": name,
"factory": entry.get("factory"),
"api_key": entry.get("api_key"),
"base_url": entry.get("base_url"),
}
return {"name": "", "factory": None, "api_key": None, "base_url": None}
def _resolve_per_model_config(entry_dict, backup_factory, backup_api_key, backup_base_url):
name = (entry_dict.get("name") or "").strip()
m_factory = entry_dict.get("factory") or backup_factory or ""
m_api_key = entry_dict.get("api_key") or backup_api_key or ""
m_base_url = entry_dict.get("base_url") or backup_base_url or ""
if name and "@" not in name and m_factory:
name = f"{name}@{m_factory}"
return {
"model": name,
"factory": m_factory,
"api_key": m_api_key,
"base_url": m_base_url,
}

View File

@ -64,9 +64,21 @@ redis:
# config:
# oss_table: 'opendal_storage'
# user_default_llm:
# factory: 'Tongyi-Qianwen'
# api_key: 'sk-xxxxxxxxxxxxx'
# base_url: ''
# factory: 'BAAI'
# api_key: 'backup'
# base_url: 'backup_base_url'
# default_models:
# chat_model:
# name: 'qwen2.5-7b-instruct'
# factory: 'xxxx'
# api_key: 'xxxx'
# base_url: 'https://api.xx.com'
# embedding_model:
# name: 'bge-m3'
# rerank_model: 'bge-reranker-v2'
# asr_model:
# model: 'whisper-large-v3' # alias of name
# image2text_model: ''
# oauth:
# oauth2:
# display_name: "OAuth2"

View File

@ -268,7 +268,7 @@ class LocalAIRerank(Base):
max_rank = np.max(rank)
# Avoid division by zero if all ranks are identical
if max_rank - min_rank != 0:
if not np.isclose(min_rank, max_rank, atol=1e-3):
rank = (rank - min_rank) / (max_rank - min_rank)
else:
rank = np.zeros_like(rank)

View File

@ -30,15 +30,16 @@ interface LlmSettingFieldItemsProps {
export const LlmSettingSchema = {
llm_id: z.string(),
temperature: z.coerce.number(),
top_p: z.string(),
presence_penalty: z.coerce.number(),
frequency_penalty: z.coerce.number(),
temperatureEnabled: z.boolean(),
topPEnabled: z.boolean(),
presencePenaltyEnabled: z.boolean(),
frequencyPenaltyEnabled: z.boolean(),
maxTokensEnabled: z.boolean(),
temperature: z.coerce.number().optional(),
top_p: z.number().optional(),
presence_penalty: z.coerce.number().optional(),
frequency_penalty: z.coerce.number().optional(),
temperatureEnabled: z.boolean().optional(),
topPEnabled: z.boolean().optional(),
presencePenaltyEnabled: z.boolean().optional(),
frequencyPenaltyEnabled: z.boolean().optional(),
maxTokensEnabled: z.boolean().optional(),
max_tokens: z.number().optional(),
};
export function LlmSettingFieldItems({
@ -53,13 +54,6 @@ export function LlmSettingFieldItems({
LlmModelType.Image2text,
]);
const handleChange = useHandleFreedomChange();
const parameterOptions = Object.values(ModelVariableType).map((x) => ({
label: t(camelCase(x)),
value: x,
}));
const getFieldWithPrefix = useCallback(
(name: string) => {
return prefix ? `${prefix}.${name}` : name;
@ -67,6 +61,13 @@ export function LlmSettingFieldItems({
[prefix],
);
const handleChange = useHandleFreedomChange(getFieldWithPrefix);
const parameterOptions = Object.values(ModelVariableType).map((x) => ({
label: t(camelCase(x)),
value: x,
}));
return (
<div className="space-y-5">
<FormField
@ -77,6 +78,7 @@ export function LlmSettingFieldItems({
<FormLabel>{t('model')}</FormLabel>
<FormControl>
<SelectWithSearch
allowClear
options={options || modelOptions}
{...field}
></SelectWithSearch>

View File

@ -9,7 +9,7 @@ import {
FormLabel,
FormMessage,
} from '../ui/form';
import { Input } from '../ui/input';
import { NumberInput } from '../ui/input';
import { Switch } from '../ui/switch';
type SliderInputSwitchFormFieldProps = {
@ -73,15 +73,14 @@ export function SliderInputSwitchFormField({
></SingleFormSlider>
</FormControl>
<FormControl>
<Input
<NumberInput
disabled={disabled}
type={'number'}
className="h-7 w-20"
max={max}
min={min}
step={step}
{...field}
></Input>
></NumberInput>
</FormControl>
</div>
<FormMessage />

View File

@ -4,7 +4,9 @@ import useGraphStore from '@/pages/agent/store';
import { useCallback, useContext } from 'react';
import { useFormContext } from 'react-hook-form';
export function useHandleFreedomChange() {
export function useHandleFreedomChange(
getFieldWithPrefix: (name: string) => string,
) {
const form = useFormContext();
const node = useContext(AgentFormContext);
const updateNodeForm = useGraphStore((state) => state.updateNodeForm);
@ -25,13 +27,14 @@ export function useHandleFreedomChange() {
for (const key in values) {
if (Object.prototype.hasOwnProperty.call(values, key)) {
const element = values[key];
const realKey = getFieldWithPrefix(key);
const element = values[key as keyof typeof values];
form.setValue(key, element);
form.setValue(realKey, element);
}
}
},
[form, node, updateNodeForm],
[form, getFieldWithPrefix, node?.id, updateNodeForm],
);
return handleChange;

View File

@ -5,6 +5,7 @@ import { Select as AntSelect, Form, message, Slider } from 'antd';
import { useCallback } from 'react';
import { useFormContext } from 'react-hook-form';
import { z } from 'zod';
import { SelectWithSearch } from './originui/select-with-search';
import { SliderInputFormField } from './slider-input-form-field';
import {
FormControl,
@ -13,7 +14,6 @@ import {
FormLabel,
FormMessage,
} from './ui/form';
import { RAGFlowSelect } from './ui/select';
type FieldType = {
rerank_id?: string;
@ -109,11 +109,11 @@ function RerankFormField() {
<FormItem>
<FormLabel tooltip={t('rerankTip')}>{t('rerankModel')}</FormLabel>
<FormControl>
<RAGFlowSelect
<SelectWithSearch
allowClear
{...field}
options={options}
></RAGFlowSelect>
></SelectWithSearch>
</FormControl>
<FormMessage />
</FormItem>
@ -122,6 +122,11 @@ function RerankFormField() {
);
}
export const rerankFormSchema = {
[RerankId]: z.string().optional(),
top_k: z.coerce.number().optional(),
};
export function RerankFormFields() {
const { watch } = useFormContext();
const { t } = useTranslate('knowledgeDetails');

View File

@ -61,11 +61,20 @@ export const keywordsSimilarityWeightSchema = {
keywords_similarity_weight: z.number(),
};
export const vectorSimilarityWeightSchema = {
vector_similarity_weight: z.number(),
};
export const initialVectorSimilarityWeightValue = {
vector_similarity_weight: 0.3,
};
export function SimilaritySliderFormField({
vectorSimilarityWeightName = 'vector_similarity_weight',
isTooltipShown,
}: SimilaritySliderFormFieldProps) {
const { t } = useTranslate('knowledgeDetails');
const isVector = vectorSimilarityWeightName === 'vector_similarity_weight';
return (
<>
@ -78,10 +87,19 @@ export function SimilaritySliderFormField({
></SliderInputFormField>
<SliderInputFormField
name={vectorSimilarityWeightName}
label={t('vectorSimilarityWeight')}
label={t(
isVector ? 'vectorSimilarityWeight' : 'keywordSimilarityWeight',
)}
max={1}
step={0.01}
tooltip={isTooltipShown && t('vectorSimilarityWeightTip')}
tooltip={
isTooltipShown &&
t(
isVector
? 'vectorSimilarityWeightTip'
: 'keywordSimilarityWeightTip',
)
}
></SliderInputFormField>
</>
);

View File

@ -10,7 +10,7 @@ import {
FormLabel,
FormMessage,
} from './ui/form';
import { Input } from './ui/input';
import { NumberInput } from './ui/input';
export type FormLayoutType = {
layout?: FormLayout;
@ -79,19 +79,14 @@ export function SliderInputFormField({
></SingleFormSlider>
</FormControl>
<FormControl>
<Input
type={'number'}
<NumberInput
className="h-7 w-20"
max={max}
min={min}
step={step}
{...field}
onChange={(ev) => {
const value = ev.target.value;
field.onChange(value === '' ? 0 : Number(value)); // convert to number
}}
// defaultValue={defaultValue}
></Input>
></NumberInput>
</FormControl>
</div>
<FormMessage />

View File

@ -1,5 +1,6 @@
import { useTranslate } from '@/hooks/common-hooks';
import { Form, Slider } from 'antd';
import { z } from 'zod';
import { SliderInputFormField } from './slider-input-form-field';
type FieldType = {
@ -32,6 +33,10 @@ interface SimilaritySliderFormFieldProps {
max?: number;
}
export const topnSchema = {
top_n: z.number().optional(),
};
export function TopNFormField({ max = 30 }: SimilaritySliderFormFieldProps) {
const { t } = useTranslate('chat');

View File

@ -124,9 +124,11 @@ export default {
similarityThreshold: 'Similarity threshold',
similarityThresholdTip:
'RAGFlow employs either a combination of weighted keyword similarity and weighted vector cosine similarity, or a combination of weighted keyword similarity and weighted reranking score during retrieval. This parameter sets the threshold for similarities between the user query and chunks. Any chunk with a similarity score below this threshold will be excluded from the results. By default, the threshold is set to 0.2. This means that only chunks with hybrid similarity score of 20 or higher will be retrieved.',
vectorSimilarityWeight: 'Keyword similarity weight',
vectorSimilarityWeight: 'Vector similarity weight',
vectorSimilarityWeightTip:
'This sets the weight of keyword similarity in the combined similarity score, either used with vector cosine similarity or with reranking score. The total of the two weights must equal 1.0.',
keywordSimilarityWeight: 'Keyword similarity weight',
keywordSimilarityWeightTip: '',
testText: 'Test text',
testTextPlaceholder: 'Input your question here!',
testingLabel: 'Testing',

View File

@ -115,7 +115,7 @@ export default {
similarityThreshold: '相似度阈值',
similarityThresholdTip:
'我们使用混合相似度得分来评估两行文本之间的距离。 它是加权关键词相似度和向量余弦相似度。 如果查询和块之间的相似度小于此阈值,则该块将被过滤掉。默认设置为 0.2,也就是说文本块的混合相似度得分至少 20 才会被召回。',
vectorSimilarityWeight: '关键字相似度权重',
vectorSimilarityWeight: '相似度相似度权重',
vectorSimilarityWeightTip:
'我们使用混合相似性评分来评估两行文本之间的距离。它是加权关键字相似性和矢量余弦相似性或rerank得分0〜1。两个权重的总和为1.0。',
testText: '测试文本',

View File

@ -1,38 +1,8 @@
import { LlmSettingFieldItems } from '@/components/llm-setting-items/next';
import {
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from '@/components/ui/form';
import { Textarea } from '@/components/ui/textarea';
import { useTranslate } from '@/hooks/common-hooks';
import { useFormContext } from 'react-hook-form';
export function ChatModelSettings() {
const { t } = useTranslate('chat');
const form = useFormContext();
return (
<div className="space-y-8">
<FormField
control={form.control}
name="prompt_config.system"
render={({ field }) => (
<FormItem>
<FormLabel>{t('system')}</FormLabel>
<FormControl>
<Textarea
placeholder="Tell us a little bit about yourself"
className="resize-none"
{...field}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<LlmSettingFieldItems prefix="llm_setting"></LlmSettingFieldItems>
</div>
);

View File

@ -15,6 +15,7 @@ import { Textarea } from '@/components/ui/textarea';
import { UseKnowledgeGraphFormField } from '@/components/use-knowledge-graph-item';
import { useTranslate } from '@/hooks/common-hooks';
import { useFormContext } from 'react-hook-form';
import { DynamicVariableForm } from './dynamic-variable';
export function ChatPromptEngine() {
const { t } = useTranslate('chat');
@ -29,11 +30,7 @@ export function ChatPromptEngine() {
<FormItem>
<FormLabel>{t('system')}</FormLabel>
<FormControl>
<Textarea
placeholder="Tell us a little bit about yourself"
className="resize-none"
{...field}
/>
<Textarea {...field} />
</FormControl>
<FormMessage />
</FormItem>
@ -47,6 +44,7 @@ export function ChatPromptEngine() {
></SwitchFormField>
<UseKnowledgeGraphFormField name="prompt_config.use_kg"></UseKnowledgeGraphFormField>
<RerankFormFields></RerankFormFields>
<DynamicVariableForm></DynamicVariableForm>
</div>
);
}

View File

@ -1,11 +1,13 @@
import { Button } from '@/components/ui/button';
import { ButtonLoading } from '@/components/ui/button';
import { Form } from '@/components/ui/form';
import { Separator } from '@/components/ui/separator';
import { useFetchDialog } from '@/hooks/use-chat-request';
import { transformBase64ToFile } from '@/utils/file-util';
import { useFetchDialog, useSetDialog } from '@/hooks/use-chat-request';
import { transformBase64ToFile, transformFile2Base64 } from '@/utils/file-util';
import { zodResolver } from '@hookform/resolvers/zod';
import { PanelRightClose } from 'lucide-react';
import { X } from 'lucide-react';
import { useEffect } from 'react';
import { FormProvider, useForm } from 'react-hook-form';
import { useForm } from 'react-hook-form';
import { useParams } from 'umi';
import { z } from 'zod';
import ChatBasicSetting from './chat-basic-settings';
import { ChatModelSettings } from './chat-model-settings';
@ -16,8 +18,12 @@ type ChatSettingsProps = { switchSettingVisible(): void };
export function ChatSettings({ switchSettingVisible }: ChatSettingsProps) {
const formSchema = useChatSettingSchema();
const { data } = useFetchDialog();
const { setDialog, loading } = useSetDialog();
const { id } = useParams();
const form = useForm<z.infer<typeof formSchema>>({
type FormSchemaType = z.infer<typeof formSchema>;
const form = useForm<FormSchemaType>({
resolver: zodResolver(formSchema),
defaultValues: {
name: '',
@ -35,8 +41,22 @@ export function ChatSettings({ switchSettingVisible }: ChatSettingsProps) {
},
});
function onSubmit(values: z.infer<typeof formSchema>) {
console.log(values);
async function onSubmit(values: FormSchemaType) {
const icon = values.icon;
const avatar =
Array.isArray(icon) && icon.length > 0
? await transformFile2Base64(icon[0])
: '';
setDialog({
...data,
...values,
icon: avatar,
dialog_id: id,
});
}
function onInvalid(errors: any) {
console.log('Form validation failed:', errors);
}
useEffect(() => {
@ -44,32 +64,33 @@ export function ChatSettings({ switchSettingVisible }: ChatSettingsProps) {
...data,
icon: data.icon ? [transformBase64ToFile(data.icon)] : [],
};
form.reset(nextData as z.infer<typeof formSchema>);
form.reset(nextData as FormSchemaType);
}, [data, form]);
return (
<section className="p-5 w-[400px] max-w-[20%]">
<section className="p-5 w-[440px] ">
<div className="flex justify-between items-center text-base">
Chat Settings
<PanelRightClose
className="size-4 cursor-pointer"
onClick={switchSettingVisible}
/>
<X className="size-4 cursor-pointer" onClick={switchSettingVisible} />
</div>
<FormProvider {...form}>
<form
onSubmit={form.handleSubmit(onSubmit)}
className="space-y-6 overflow-auto max-h-[87vh] pr-4"
>
<ChatBasicSetting></ChatBasicSetting>
<Separator />
<ChatPromptEngine></ChatPromptEngine>
<Separator />
<ChatModelSettings></ChatModelSettings>
<Form {...form}>
<form onSubmit={form.handleSubmit(onSubmit, onInvalid)}>
<section className="space-y-6 overflow-auto max-h-[87vh] pr-4">
<ChatBasicSetting></ChatBasicSetting>
<Separator />
<ChatPromptEngine></ChatPromptEngine>
<Separator />
<ChatModelSettings></ChatModelSettings>
</section>
<ButtonLoading
className="w-full my-4"
type="submit"
loading={loading}
>
Update
</ButtonLoading>
</form>
</FormProvider>
<Button className="w-full my-4">Update</Button>
</Form>
</section>
);
}

View File

@ -0,0 +1,89 @@
import { Button } from '@/components/ui/button';
import {
FormControl,
FormField,
FormItem,
FormLabel,
FormMessage,
} from '@/components/ui/form';
import { BlurInput } from '@/components/ui/input';
import { Separator } from '@/components/ui/separator';
import { Switch } from '@/components/ui/switch';
import { Plus, X } from 'lucide-react';
import { useCallback } from 'react';
import { useFieldArray, useFormContext } from 'react-hook-form';
import { useTranslation } from 'react-i18next';
export function DynamicVariableForm() {
const { t } = useTranslation();
const form = useFormContext();
const name = 'prompt_config.parameters';
const { fields, remove, append } = useFieldArray({
name,
control: form.control,
});
const add = useCallback(() => {
append({
key: undefined,
optional: false,
});
}, [append]);
return (
<section className="flex flex-col gap-2">
<div className="flex items-center justify-between">
<FormLabel tooltip={t('chat.variableTip')}>
{t('chat.variable')}
</FormLabel>
<Button variant={'ghost'} type="button" onClick={add}>
<Plus />
</Button>
</div>
<div className="space-y-5">
{fields.map((field, index) => {
const typeField = `${name}.${index}.key`;
return (
<div key={field.id} className="flex w-full items-center gap-2">
<FormField
control={form.control}
name={typeField}
render={({ field }) => (
<FormItem className="flex-1 overflow-hidden">
<FormControl>
<BlurInput
{...field}
placeholder={t('common.pleaseInput')}
></BlurInput>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<Separator className="w-3 text-text-secondary" />
<FormField
control={form.control}
name={`${name}.${index}.optional`}
render={({ field }) => (
<FormItem className="flex-1 overflow-hidden">
<FormControl>
<Switch
checked={field.value}
onCheckedChange={field.onChange}
></Switch>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<Button variant={'ghost'} onClick={() => remove(index)}>
<X className="text-text-sub-title-invert " />
</Button>
</div>
);
})}
</div>
</section>
);
}

View File

@ -1,5 +1,9 @@
import { LlmSettingSchema } from '@/components/llm-setting-items/next';
import { rerankFormSchema } from '@/components/rerank';
import { vectorSimilarityWeightSchema } from '@/components/similarity-slider';
import { topnSchema } from '@/components/top-n-item';
import { useTranslate } from '@/hooks/common-hooks';
import { omit } from 'lodash';
import { z } from 'zod';
export function useChatSettingSchema() {
@ -9,13 +13,17 @@ export function useChatSettingSchema() {
quote: z.boolean(),
keyword: z.boolean(),
tts: z.boolean(),
empty_response: z.string().min(1, {
message: t('emptyResponse'),
}),
prologue: z.string().min(1, {}),
empty_response: z.string().optional(),
prologue: z.string().optional(),
system: z.string().min(1, { message: t('systemMessage') }),
refine_multiturn: z.boolean(),
use_kg: z.boolean(),
parameters: z.array(
z.object({
key: z.string(),
optional: z.boolean(),
}),
),
});
const formSchema = z.object({
@ -29,10 +37,11 @@ export function useChatSettingSchema() {
message: 'Username must be at least 1 characters.',
}),
prompt_config: promptConfigSchema,
top_n: z.number(),
vector_similarity_weight: z.number(),
top_k: z.number(),
llm_setting: z.object(LlmSettingSchema),
...rerankFormSchema,
llm_setting: z.object(omit(LlmSettingSchema, 'llm_id')),
llm_id: z.string().optional(),
...vectorSimilarityWeightSchema,
...topnSchema,
});
return formSchema;

View File

@ -47,7 +47,7 @@ export function Sessions({
}
return (
<section className="p-6 w-[400px] max-w-[20%] flex flex-col">
<section className="p-6 w-[296px] flex flex-col">
<section className="flex items-center text-base justify-between gap-2">
<div className="flex gap-3 items-center min-w-0">
<RAGFlowAvatar