Compare commits

...

7 Commits

Author SHA1 Message Date
5e8cd693a5 Refa: split services about llm. (#9450)
### What problem does this PR solve?

### Type of change

- [x] Refactoring
2025-08-13 16:41:01 +08:00
29f297b850 Fix: update broken create agent session due to v0.20.0 changes (#9445)
### What problem does this PR solve?

 Update broken create agent session due to v0.20.0 changes. #9383


**NOTE: A session ID is no longer required to interact with the agent.**

See: #9241, #9309.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-08-13 16:01:54 +08:00
7235638607 Feat: Show multiple chat boxes #3221 (#9443)
### What problem does this PR solve?

Feat: Show multiple chat boxes #3221

### Type of change


- [x] New Feature (non-breaking change which adds functionality)
2025-08-13 15:59:51 +08:00
00919fd599 Fix typo in issue template (#9444) 2025-08-13 14:27:15 +08:00
43c0792ffd Add issue template for agent scenario feature request (#9437) 2025-08-13 12:50:06 +08:00
4b1b68c5fc Fix: no doc hits after meta data filter. (#9435)
### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-08-13 12:43:31 +08:00
3492f54c7a Docs: Update HTTP API reference with new response fields (#9434)
### What problem does this PR solve?

Add `url`, `doc_type`, and `created_at` fields to the API response
example in the documentation.

### Type of change

- [x] Documentation Update
2025-08-13 12:18:39 +08:00
31 changed files with 886 additions and 523 deletions

View File

@ -0,0 +1,46 @@
name: "❤️‍🔥ᴬᴳᴱᴺᵀ Agent scenario request"
description: Propose a agent scenario request for RAGFlow.
title: "[Agent Scenario Request]: "
labels: ["❤️‍🔥ᴬᴳᴱᴺᵀ agent scenario"]
body:
- type: checkboxes
attributes:
label: Self Checks
description: "Please check the following in order to be responded in time :)"
options:
- label: I have searched for existing issues [search for existing issues](https://github.com/infiniflow/ragflow/issues), including closed ones.
required: true
- label: I confirm that I am using English to submit this report ([Language Policy](https://github.com/infiniflow/ragflow/issues/5910)).
required: true
- label: Non-english title submitions will be closed directly ( 非英文标题的提交将会被直接关闭 ) ([Language Policy](https://github.com/infiniflow/ragflow/issues/5910)).
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
- type: textarea
attributes:
label: Is your feature request related to a scenario?
description: |
A clear and concise description of what the scenario is. Ex. I'm always frustrated when [...]
render: Markdown
validations:
required: false
- type: textarea
attributes:
label: Describe the feature you'd like
description: A clear and concise description of what you want to happen.
validations:
required: true
- type: textarea
attributes:
label: Documentation, adoption, use case
description: If you can, explain some scenarios how users might use this, situations it would be helpful in. Any API designs, mockups, or diagrams are also helpful.
render: Markdown
validations:
required: false
- type: textarea
attributes:
label: Additional information
description: |
Add any other context or screenshots about the feature request here.
validations:
required: false

View File

@ -24,7 +24,8 @@ from typing import Any
import json_repair
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in

View File

@ -24,7 +24,8 @@ from copy import deepcopy
from functools import partial
from api.db import LLMType
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in, citation_prompt

View File

@ -28,8 +28,8 @@ from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantService
from api.db.services.user_service import UserTenantService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import UserTenantService, TenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.tag import label_question

View File

@ -18,7 +18,7 @@ from flask import request
from flask_login import login_required, current_user
from api.db.services.dialog_service import DialogService
from api.db import StatusEnum
from api.db.services.llm_service import TenantLLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.user_service import TenantService, UserTenantService
from api import settings

View File

@ -17,7 +17,8 @@ import logging
import json
from flask import request
from flask_login import login_required, current_user
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService
from api import settings
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.db import StatusEnum, LLMType

View File

@ -21,7 +21,7 @@ from api import settings
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required

View File

@ -32,7 +32,8 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.task_service import TaskService, queue_tasks
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
from rag.app.qa import beAdoc, rmPrefix

View File

@ -16,10 +16,8 @@
import json
import re
import time
import tiktoken
from flask import Response, jsonify, request
from agent.canvas import Canvas
from api.db import LLMType, StatusEnum
from api.db.db_models import APIToken
@ -29,7 +27,6 @@ from api.db.services.canvas_service import completion as agent_completion
from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion
from api.db.services.dialog_service import DialogService, ask, chat
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid
@ -69,11 +66,7 @@ def create(tenant_id, chat_id):
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
def create_agent_session(tenant_id, agent_id):
req = request.json
if not request.is_json:
req = request.form
files = request.files
user_id = request.args.get("user_id", "")
user_id = request.args.get("user_id", tenant_id)
e, cvs = UserCanvasService.get_by_id(agent_id)
if not e:
return get_error_data_result("Agent not found.")
@ -82,46 +75,21 @@ def create_agent_session(tenant_id, agent_id):
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
canvas = Canvas(cvs.dsl, tenant_id)
session_id=get_uuid()
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
canvas.reset()
query = canvas.get_preset_param()
if query:
for ele in query:
if not ele["optional"]:
if ele["type"] == "file":
if files is None or not files.get(ele["key"]):
return get_error_data_result(f"`{ele['key']}` with type `{ele['type']}` is required")
upload_file = files.get(ele["key"])
file_content = FileService.parse_docs([upload_file], user_id)
file_name = upload_file.filename
ele["value"] = file_name + "\n" + file_content
else:
if req is None or not req.get(ele["key"]):
return get_error_data_result(f"`{ele['key']}` with type `{ele['type']}` is required")
ele["value"] = req[ele["key"]]
else:
if ele["type"] == "file":
if files is not None and files.get(ele["key"]):
upload_file = files.get(ele["key"])
file_content = FileService.parse_docs([upload_file], user_id)
file_name = upload_file.filename
ele["value"] = file_name + "\n" + file_content
else:
if "value" in ele:
ele.pop("value")
else:
if req is not None and req.get(ele["key"]):
ele["value"] = req[ele["key"]]
else:
if "value" in ele:
ele.pop("value")
for ans in canvas.run(stream=False):
pass
conv = {
"id": session_id,
"dialog_id": cvs.id,
"user_id": user_id,
"message": [],
"source": "agent",
"dsl": cvs.dsl
}
API4ConversationService.save(**conv)
cvs.dsl = json.loads(str(canvas))
conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
API4ConversationService.save(**conv)
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
conv["agent_id"] = conv.pop("dialog_id")
return get_result(data=conv)

View File

@ -28,7 +28,7 @@ from api.apps.auth import get_auth_client
from api.db import FileType, UserTenantRole
from api.db.db_models import TenantLLM
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.llm_service import TenantLLMService, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserService, UserTenantService
from api.utils import (
current_timestamp,
@ -619,57 +619,8 @@ def user_register(user_id, user):
"size": 0,
"location": "",
}
tenant_llm = []
seen = set()
factory_configs = []
for factory_config in [
settings.CHAT_CFG,
settings.EMBEDDING_CFG,
settings.ASR_CFG,
settings.IMAGE2TEXT_CFG,
settings.RERANK_CFG,
]:
factory_name = factory_config["factory"]
if factory_name not in seen:
seen.add(factory_name)
factory_configs.append(factory_config)
for factory_config in factory_configs:
for llm in LLMService.query(fid=factory_config["factory"]):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": factory_config["factory"],
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": factory_config["api_key"],
"api_base": factory_config["base_url"],
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
if settings.LIGHTEN != 1:
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": fid,
"llm_name": mdlnm,
"model_type": "embedding",
"api_key": "",
"api_base": "",
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
}
)
unique = {}
for item in tenant_llm:
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
if key not in unique:
unique[key] = item
tenant_llm = list(unique.values())
tenant_llm = get_init_tenant_llm(user_id)
if not UserService.save(**user):
return

View File

@ -27,7 +27,8 @@ from api.db.services import UserService
from api.db.services.canvas_service import CanvasTemplateService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
from api.db.services.user_service import TenantService, UserTenantService
from api import settings
from api.utils.file_utils import get_project_base_directory
@ -64,43 +65,7 @@ def init_superuser():
"role": UserTenantRole.OWNER
}
user_id = user_info
tenant_llm = []
seen = set()
factory_configs = []
for factory_config in [
settings.CHAT_CFG["factory"],
settings.EMBEDDING_CFG["factory"],
settings.ASR_CFG["factory"],
settings.IMAGE2TEXT_CFG["factory"],
settings.RERANK_CFG["factory"],
]:
factory_name = factory_config["factory"]
if factory_name not in seen:
seen.add(factory_name)
factory_configs.append(factory_config)
for factory_config in factory_configs:
for llm in LLMService.query(fid=factory_config["factory"]):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": factory_config["factory"],
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": factory_config["api_key"],
"api_base": factory_config["base_url"],
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
unique = {}
for item in tenant_llm:
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
if key not in unique:
unique[key] = item
tenant_llm = list(unique.values())
tenant_llm = get_init_tenant_llm(user_info["id"])
if not UserService.save(**user_info):
logging.error("can't init admin.")

View File

@ -33,7 +33,8 @@ from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import current_timestamp, datetime_format
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
@ -365,8 +366,12 @@ def chat(dialog, messages, stream=True, **kwargs):
if dialog.meta_data_filter.get("method") == "auto":
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
attachments.extend(meta_filter(metas, filters))
if not attachments:
attachments = None
elif dialog.meta_data_filter.get("method") == "manual":
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
if not attachments:
attachments = None
if prompt_config.get("keyword", False):
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
@ -375,17 +380,16 @@ def chat(dialog, messages, stream=True, **kwargs):
thought = ""
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
knowledges = []
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
knowledges = []
else:
if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
knowledges = []
if prompt_config.get("reasoning", False):
reasoner = DeepResearcher(
chat_mdl,
prompt_config,
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3),
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3, doc_ids=attachments),
)
for think in reasoner.thinking(kbinfos, " ".join(questions)):

View File

@ -18,246 +18,73 @@ import logging
import re
from functools import partial
from typing import Generator
from langfuse import Langfuse
from api import settings
from api.db import LLMType
from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.user_service import TenantService
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
class LLMFactoriesService(CommonService):
model = LLMFactories
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
class LLMService(CommonService):
model = LLM
class TenantLLMService(CommonService):
model = TenantLLM
def get_init_tenant_llm(user_id):
from api import settings
tenant_llm = []
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name):
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
if not fid:
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
else:
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
seen = set()
factory_configs = []
for factory_config in [
settings.CHAT_CFG,
settings.EMBEDDING_CFG,
settings.ASR_CFG,
settings.IMAGE2TEXT_CFG,
settings.RERANK_CFG,
]:
factory_name = factory_config["factory"]
if factory_name not in seen:
seen.add(factory_name)
factory_configs.append(factory_config)
if (not objs) and fid:
if fid == "LocalAI":
mdlnm += "___LocalAI"
elif fid == "HuggingFace":
mdlnm += "___HuggingFace"
elif fid == "OpenAI-API-Compatible":
mdlnm += "___OpenAI-API"
elif fid == "VLLM":
mdlnm += "___VLLM"
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
if not objs:
return
return objs[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)
@staticmethod
def split_model_name_and_factory(model_name):
arr = model_name.split("@")
if len(arr) < 2:
return model_name, None
if len(arr) > 2:
return "@".join(arr[0:-1]), arr[-1]
# model name must be xxx@yyy
try:
model_factories = settings.FACTORY_LLM_INFOS
model_providers = set([f["name"] for f in model_factories])
if arr[-1] not in model_providers:
return model_name, None
return arr[0], arr[-1]
except Exception as e:
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
return model_name, None
@classmethod
@DB.connection_context()
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
if llm_type == LLMType.EMBEDDING.value:
mdlnm = tenant.embd_id if not llm_name else llm_name
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value:
mdlnm = tenant.img2txt_id if not llm_name else llm_name
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
mdlnm = tenant.rerank_id if not llm_name else llm_name
elif llm_type == LLMType.TTS:
mdlnm = tenant.tts_id if not llm_name else llm_name
else:
assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
if not model_config: # for some cases seems fid mismatch
model_config = cls.get_api_key(tenant_id, mdlnm)
if model_config:
model_config = model_config.to_dict()
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if not llm and fid: # for some cases seems fid mismatch
llm = LLMService.query(llm_name=mdlnm)
if llm:
model_config["is_tools"] = llm[0].is_tools
if not model_config:
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
if not model_config:
if mdlnm == "flag-embedding":
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
else:
if not mdlnm:
raise LookupError(f"Type of {llm_type} model is not set.")
raise LookupError("Model({}) not authorized".format(mdlnm))
return model_config
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
kwargs.update({"provider": model_config["llm_factory"]})
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
if llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return
return TTSModel[model_config["llm_factory"]](
model_config["api_key"],
model_config["llm_name"],
base_url=model_config["api_base"],
for factory_config in factory_configs:
for llm in LLMService.query(fid=factory_config["factory"]):
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": factory_config["factory"],
"llm_name": llm.llm_name,
"model_type": llm.model_type,
"api_key": factory_config["api_key"],
"api_base": factory_config["base_url"],
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
}
)
@classmethod
@DB.connection_context()
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
logging.error(f"Tenant not found: {tenant_id}")
return 0
llm_map = {
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
LLMType.SPEECH2TEXT.value: tenant.asr_id,
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
}
mdlnm = llm_map.get(llm_type)
if mdlnm is None:
logging.error(f"LLM type error: {llm_type}")
return 0
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
try:
num = (
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
.execute()
if settings.LIGHTEN != 1:
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
tenant_llm.append(
{
"tenant_id": user_id,
"llm_factory": fid,
"llm_name": mdlnm,
"model_type": "embedding",
"api_key": "",
"api_base": "",
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
}
)
except Exception:
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
return 0
return num
@classmethod
@DB.connection_context()
def get_openai_models(cls):
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)
@staticmethod
def llm_id2llm_type(llm_id: str) -> str | None:
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
return llm["model_type"].split(",")[-1]
for llm in LLMService.query(llm_name=llm_id):
return llm.model_type
llm = TenantLLMService.get_or_none(llm_name=llm_id)
if llm:
return llm.model_type
for llm in TenantLLMService.query(llm_name=llm_id):
return llm.model_type
unique = {}
for item in tenant_llm:
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
if key not in unique:
unique[key] = item
return list(unique.values())
class LLMBundle:
class LLMBundle(LLM4Tenant):
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.llm_type = llm_type
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
self.max_length = model_config.get("max_tokens", 8192)
self.is_tools = model_config.get("is_tools", False)
self.verbose_tool_use = kwargs.get("verbose_tool_use")
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
self.langfuse = None
if langfuse_keys:
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
if langfuse.auth_check():
self.langfuse = langfuse
trace_id = self.langfuse.create_trace_id()
self.trace_context = {"trace_id": trace_id}
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
def bind_tools(self, toolcall_session, tools):
if not self.is_tools:

View File

@ -0,0 +1,252 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from langfuse import Langfuse
from api import settings
from api.db import LLMType
from api.db.db_models import DB, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.user_service import TenantService
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
class LLMFactoriesService(CommonService):
model = LLMFactories
class TenantLLMService(CommonService):
model = TenantLLM
@classmethod
@DB.connection_context()
def get_api_key(cls, tenant_id, model_name):
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
if not fid:
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
else:
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
if (not objs) and fid:
if fid == "LocalAI":
mdlnm += "___LocalAI"
elif fid == "HuggingFace":
mdlnm += "___HuggingFace"
elif fid == "OpenAI-API-Compatible":
mdlnm += "___OpenAI-API"
elif fid == "VLLM":
mdlnm += "___VLLM"
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
if not objs:
return
return objs[0]
@classmethod
@DB.connection_context()
def get_my_llms(cls, tenant_id):
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
return list(objs)
@staticmethod
def split_model_name_and_factory(model_name):
arr = model_name.split("@")
if len(arr) < 2:
return model_name, None
if len(arr) > 2:
return "@".join(arr[0:-1]), arr[-1]
# model name must be xxx@yyy
try:
model_factories = settings.FACTORY_LLM_INFOS
model_providers = set([f["name"] for f in model_factories])
if arr[-1] not in model_providers:
return model_name, None
return arr[0], arr[-1]
except Exception as e:
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
return model_name, None
@classmethod
@DB.connection_context()
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
from api.db.services.llm_service import LLMService
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
raise LookupError("Tenant not found")
if llm_type == LLMType.EMBEDDING.value:
mdlnm = tenant.embd_id if not llm_name else llm_name
elif llm_type == LLMType.SPEECH2TEXT.value:
mdlnm = tenant.asr_id
elif llm_type == LLMType.IMAGE2TEXT.value:
mdlnm = tenant.img2txt_id if not llm_name else llm_name
elif llm_type == LLMType.CHAT.value:
mdlnm = tenant.llm_id if not llm_name else llm_name
elif llm_type == LLMType.RERANK:
mdlnm = tenant.rerank_id if not llm_name else llm_name
elif llm_type == LLMType.TTS:
mdlnm = tenant.tts_id if not llm_name else llm_name
else:
assert False, "LLM type error"
model_config = cls.get_api_key(tenant_id, mdlnm)
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
if not model_config: # for some cases seems fid mismatch
model_config = cls.get_api_key(tenant_id, mdlnm)
if model_config:
model_config = model_config.to_dict()
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if not llm and fid: # for some cases seems fid mismatch
llm = LLMService.query(llm_name=mdlnm)
if llm:
model_config["is_tools"] = llm[0].is_tools
if not model_config:
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
if not model_config:
if mdlnm == "flag-embedding":
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
else:
if not mdlnm:
raise LookupError(f"Type of {llm_type} model is not set.")
raise LookupError("Model({}) not authorized".format(mdlnm))
return model_config
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
kwargs.update({"provider": model_config["llm_factory"]})
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
return
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.RERANK:
if model_config["llm_factory"] not in RerankModel:
return
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
return
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
if llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return
return TTSModel[model_config["llm_factory"]](
model_config["api_key"],
model_config["llm_name"],
base_url=model_config["api_base"],
)
@classmethod
@DB.connection_context()
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
e, tenant = TenantService.get_by_id(tenant_id)
if not e:
logging.error(f"Tenant not found: {tenant_id}")
return 0
llm_map = {
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
LLMType.SPEECH2TEXT.value: tenant.asr_id,
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
}
mdlnm = llm_map.get(llm_type)
if mdlnm is None:
logging.error(f"LLM type error: {llm_type}")
return 0
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
try:
num = (
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
.execute()
)
except Exception:
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
return 0
return num
@classmethod
@DB.connection_context()
def get_openai_models(cls):
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
return list(objs)
@staticmethod
def llm_id2llm_type(llm_id: str) -> str | None:
from api.db.services.llm_service import LLMService
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
return llm["model_type"].split(",")[-1]
for llm in LLMService.query(llm_name=llm_id):
return llm.model_type
llm = TenantLLMService.get_or_none(llm_name=llm_id)
if llm:
return llm.model_type
for llm in TenantLLMService.query(llm_name=llm_id):
return llm.model_type
class LLM4Tenant:
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.llm_type = llm_type
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
self.max_length = model_config.get("max_tokens", 8192)
self.is_tools = model_config.get("is_tools", False)
self.verbose_tool_use = kwargs.get("verbose_tool_use")
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
self.langfuse = None
if langfuse_keys:
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
if langfuse.auth_check():
self.langfuse = langfuse
trace_id = self.langfuse.create_trace_id()
self.trace_context = {"trace_id": trace_id}

View File

@ -48,7 +48,8 @@ from werkzeug.http import HTTP_STATUS_CODES
from api import settings
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
from api.db.db_models import APIToken
from api.db.services.llm_service import LLMService, TenantLLMService
from api.db.services.llm_service import LLMService
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions

View File

@ -2632,9 +2632,11 @@ data:{
"document_name": "1.txt",
"dataset_id": "8e83e57a884611ef9d760242ac120006",
"image_id": "",
"url": null,
"similarity": 0.7,
"vector_similarity": 0.0,
"term_similarity": 1.0,
"doc_type": [],
"positions": [
""
]
@ -2649,6 +2651,7 @@ data:{
]
},
"prompt": "xxxxxxxxxxx",
"created_at": 1755055623.6401553,
"id": "a84c5dd4-97b4-4624-8c3b-974012c8000d",
"session_id": "82b0ab2a9c1911ef9d870242ac120006"
}
@ -2681,7 +2684,7 @@ Creates a session with an agent.
- Method: POST
- URL: `/api/v1/agents/{agent_id}/sessions?user_id={user_id}`
- Headers:
- `'content-Type: application/json' or 'multipart/form-data'`
- `'content-Type: application/json'
- `'Authorization: Bearer <YOUR_API_KEY>'`
- Body:
- the required parameters:`str`
@ -2701,29 +2704,6 @@ curl --request POST \
}'
```
If the **Begin** component in your agent takes required parameters:
```bash
curl --request POST \
--url http://{address}/api/v1/agents/{agent_id}/sessions \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer <YOUR_API_KEY>' \
--data '{
"lang":"Japanese",
"file":"Who are you"
}'
```
If the **Begin** component in your agent takes required file parameters:
```bash
curl --request POST \
--url http://{address}/api/v1/agents/{agent_id}/sessions?user_id={user_id} \
--header 'Content-Type: multipart/form-data' \
--header 'Authorization: Bearer <YOUR_API_KEY>' \
--form '<FILE_KEY>=@./test1.png'
```
##### Request parameters
- `agent_id`: (*Path parameter*)
@ -2739,101 +2719,190 @@ Success:
{
"code": 0,
"data": {
"agent_id": "b4a39922b76611efaa1a0242ac120006",
"agent_id": "dbb4ed366e8611f09690a55a6daec4ef",
"dsl": {
"answer": [],
"components": {
"Answer:GreenReadersDrum": {
"Message:EightyJobsAsk": {
"downstream": [],
"obj": {
"component_name": "Answer",
"inputs": [],
"output": null,
"params": {}
"component_name": "Message",
"params": {
"content": [
"{begin@var1}{begin@var2}"
],
"debug_inputs": {},
"delay_after_error": 2.0,
"description": "",
"exception_default_value": null,
"exception_goto": null,
"exception_method": null,
"inputs": {},
"max_retries": 0,
"message_history_window_size": 22,
"outputs": {
"content": {
"type": "str",
"value": null
}
},
"stream": true
}
},
"upstream": []
"upstream": [
"begin"
]
},
"begin": {
"downstream": [],
"downstream": [
"Message:EightyJobsAsk"
],
"obj": {
"component_name": "Begin",
"inputs": [],
"output": {},
"params": {}
"params": {
"debug_inputs": {},
"delay_after_error": 2.0,
"description": "",
"enablePrologue": true,
"enable_tips": true,
"exception_default_value": null,
"exception_goto": null,
"exception_method": null,
"inputs": {
"var1": {
"name": "var1",
"optional": false,
"options": [],
"type": "line",
"value": null
},
"var2": {
"name": "var2",
"optional": false,
"options": [],
"type": "line",
"value": null
}
},
"max_retries": 0,
"message_history_window_size": 22,
"mode": "conversational",
"outputs": {},
"prologue": "Hi! I'm your assistant, what can I do for you?",
"tips": "Please fill up the form"
}
},
"upstream": []
}
},
"embed_id": "",
"globals": {
"sys.conversation_turns": 0,
"sys.files": [],
"sys.query": "",
"sys.user_id": ""
},
"graph": {
"edges": [],
"edges": [
{
"data": {
"isHovered": false
},
"id": "xy-edge__beginstart-Message:EightyJobsAskend",
"markerEnd": "logo",
"source": "begin",
"sourceHandle": "start",
"style": {
"stroke": "rgba(151, 154, 171, 1)",
"strokeWidth": 1
},
"target": "Message:EightyJobsAsk",
"targetHandle": "end",
"type": "buttonEdge",
"zIndex": 1001
}
],
"nodes": [
{
"data": {
"form": {
"enablePrologue": true,
"inputs": {
"var1": {
"name": "var1",
"optional": false,
"options": [],
"type": "line"
},
"var2": {
"name": "var2",
"optional": false,
"options": [],
"type": "line"
}
},
"mode": "conversational",
"prologue": "Hi! I'm your assistant, what can I do for you?"
},
"label": "Begin",
"name": "begin"
},
"dragging": false,
"height": 44,
"id": "begin",
"position": {
"x": 53.25688640427177,
"y": 198.37155679786412
"measured": {
"height": 112,
"width": 200
},
"positionAbsolute": {
"x": 53.25688640427177,
"y": 198.37155679786412
"position": {
"x": 270.64098070942583,
"y": -56.320928437811176
},
"selected": false,
"sourcePosition": "left",
"targetPosition": "right",
"type": "beginNode",
"width": 200
"type": "beginNode"
},
{
"data": {
"form": {},
"label": "Answer",
"name": "dialog_0"
"form": {
"content": [
"{begin@var1}{begin@var2}"
]
},
"label": "Message",
"name": "Message_0"
},
"dragging": false,
"height": 44,
"id": "Answer:GreenReadersDrum",
"id": "Message:EightyJobsAsk",
"measured": {
"height": 57,
"width": 200
},
"position": {
"x": 360.43473114516974,
"y": 207.29298425089348
"x": 279.5,
"y": 190
},
"positionAbsolute": {
"x": 360.43473114516974,
"y": 207.29298425089348
},
"selected": false,
"selected": true,
"sourcePosition": "right",
"targetPosition": "left",
"type": "logicNode",
"width": 200
"type": "messageNode"
}
]
},
"history": [],
"memory": [],
"messages": [],
"path": [
[
"begin"
],
[]
],
"reference": []
"path": [],
"retrieval": [],
"task_id": "dbb4ed366e8611f09690a55a6daec4ef"
},
"id": "2581031eb7a311efb5200242ac120005",
"id": "0b02fe80780e11f084adcfdc3ed1d902",
"message": [
{
"content": "Hi! I'm your smart assistant. What can I do for you?",
"content": "Hi! I'm your assistant, what can I do for you?",
"role": "assistant"
}
],
"source": "agent",
"user_id": "69736c5e723611efb51b0242ac120007"
"user_id": "c3fb861af27a11efa69751e139332ced"
}
}
```

View File

@ -105,4 +105,5 @@ REMEMBER:
- Cite FACTS, not opinions or transitions
- Each citation supports the ENTIRE sentence
- When in doubt, ask: "Would a fact-checker need to verify this?"
- Place citations at sentence end, before punctuation
- Place citations at sentence end, before punctuation
- Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be seperated like, [ID:0][ID:5]...

View File

@ -197,7 +197,7 @@ def question_proposal(chat_mdl, content, topn=3):
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.llm_service import TenantLLMService
from api.db.services.tenant_llm_service import TenantLLMService
if not chat_mdl:
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
@ -231,7 +231,7 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_
def cross_languages(tenant_id, llm_id, query, languages=[]):
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.llm_service import TenantLLMService
from api.db.services.tenant_llm_service import TenantLLMService
if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)

View File

@ -191,7 +191,6 @@ class RAGFlowS3:
time.sleep(1)
return
@use_prefix_path
@use_default_bucket
def rm_bucket(self, bucket, *args, **kwargs):
for conn in self.conn:

View File

@ -0,0 +1,46 @@
// https://originui.com/r/comp-23.json
'use client';
import { EyeIcon, EyeOffIcon } from 'lucide-react';
import React, { useId, useState } from 'react';
import { Input, InputProps } from '../ui/input';
export default React.forwardRef<HTMLInputElement, InputProps>(
function PasswordInput({ ...props }, ref) {
const id = useId();
const [isVisible, setIsVisible] = useState<boolean>(false);
const toggleVisibility = () => setIsVisible((prevState) => !prevState);
return (
<div className="*:not-first:mt-2">
{/* <Label htmlFor={id}>Show/hide password input</Label> */}
<div className="relative">
<Input
id={id}
className="pe-9"
placeholder="Password"
type={isVisible ? 'text' : 'password'}
ref={ref}
{...props}
/>
<button
className="text-muted-foreground/80 hover:text-foreground focus-visible:border-ring focus-visible:ring-ring/50 absolute inset-y-0 end-0 flex h-full w-9 items-center justify-center rounded-e-md transition-[color,box-shadow] outline-none focus:z-10 focus-visible:ring-[3px] disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50"
type="button"
onClick={toggleVisibility}
aria-label={isVisible ? 'Hide password' : 'Show password'}
aria-pressed={isVisible}
aria-controls="password"
>
{isVisible ? (
<EyeOffIcon size={16} aria-hidden="true" />
) : (
<EyeIcon size={16} aria-hidden="true" />
)}
</button>
</div>
</div>
);
},
);

View File

@ -2,7 +2,7 @@ import { PropsWithChildren } from 'react';
export function PageHeader({ children }: PropsWithChildren) {
return (
<header className="flex justify-between items-center border-b bg-text-title-invert p-5">
<header className="flex justify-between items-center bg-text-title-invert p-5">
{children}
</header>
);

View File

@ -1,52 +0,0 @@
import { Input } from '@/components/originui/input';
import { EyeIcon, EyeOffIcon } from 'lucide-react';
import { ChangeEvent, forwardRef, useId, useState } from 'react';
type PropType = {
name: string;
value: string;
onBlur: () => void;
onChange: (event: ChangeEvent<HTMLInputElement>) => void;
};
function PasswordInput(props: PropType) {
const id = useId();
const [isVisible, setIsVisible] = useState<boolean>(false);
const toggleVisibility = () => setIsVisible((prevState) => !prevState);
return (
<div className="*:not-first:mt-2 w-full">
{/* <Label htmlFor={id}>Show/hide password input</Label> */}
<div className="relative">
<Input
autoComplete="off"
inputMode="numeric"
id={id}
className="pe-9"
placeholder=""
type={isVisible ? 'text' : 'password'}
value={props.value}
onBlur={props.onBlur}
onChange={(ev) => props.onChange(ev)}
/>
<button
className="text-muted-foreground/80 hover:text-foreground focus-visible:border-ring focus-visible:ring-ring/50 absolute inset-y-0 end-0 flex h-full w-9 items-center justify-center rounded-e-md transition-[color,box-shadow] outline-none focus:z-10 focus-visible:ring-[3px] disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50"
type="button"
onClick={toggleVisibility}
aria-label={isVisible ? 'Hide password' : 'Show password'}
aria-pressed={isVisible}
aria-controls="password"
>
{isVisible ? (
<EyeOffIcon size={16} aria-hidden="true" />
) : (
<EyeIcon size={16} aria-hidden="true" />
)}
</button>
</div>
</div>
);
}
export default forwardRef(PasswordInput);

View File

@ -0,0 +1,51 @@
import { useTranslate } from '@/hooks/common-hooks';
import { useFormContext } from 'react-hook-form';
import PasswordInput from './originui/password-input';
import {
FormControl,
FormDescription,
FormField,
FormItem,
FormLabel,
FormMessage,
} from './ui/form';
interface IProps {
name?: string;
}
export function TavilyFormField({
name = 'prompt_config.tavily_api_key',
}: IProps) {
const form = useFormContext();
const { t } = useTranslate('chat');
return (
<FormField
control={form.control}
name={name}
render={({ field }) => (
<FormItem>
<FormLabel tooltip={t('tavilyApiKeyTip')}>Tavily API Key</FormLabel>
<FormControl>
<PasswordInput
{...field}
placeholder={t('tavilyApiKeyMessage')}
autoComplete="new-password"
></PasswordInput>
</FormControl>
<FormDescription>
<a
href="https://app.tavily.com/home"
target={'_blank'}
rel="noreferrer"
>
{t('tavilyApiKeyHelp')}
</a>
</FormDescription>
<FormMessage />
</FormItem>
)}
/>
);
}

View File

@ -3,6 +3,7 @@
import { FileUploader } from '@/components/file-uploader';
import { KnowledgeBaseFormField } from '@/components/knowledge-base-item';
import { SwitchFormField } from '@/components/switch-fom-field';
import { TavilyFormField } from '@/components/tavily-form-field';
import {
FormControl,
FormField,
@ -105,6 +106,7 @@ export default function ChatBasicSetting() {
name={'prompt_config.tts'}
label={t('tts')}
></SwitchFormField>
<TavilyFormField></TavilyFormField>
<KnowledgeBaseFormField></KnowledgeBaseFormField>
</div>
);

View File

@ -68,8 +68,8 @@ export function ChatSettings({ switchSettingVisible }: ChatSettingsProps) {
}, [data, form]);
return (
<section className="p-5 w-[440px] ">
<div className="flex justify-between items-center text-base">
<section className="p-5 w-[440px] border-l">
<div className="flex justify-between items-center text-base pb-2">
Chat Settings
<X className="size-4 cursor-pointer" onClick={switchSettingVisible} />
</div>

View File

@ -24,6 +24,7 @@ export function useChatSettingSchema() {
optional: z.boolean(),
}),
),
tavily_api_key: z.string().optional(),
});
const formSchema = z.object({

View File

@ -0,0 +1,155 @@
import { NextMessageInput } from '@/components/message-input/next';
import MessageItem from '@/components/message-item';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { MessageType } from '@/constants/chat';
import {
useFetchConversation,
useFetchDialog,
useGetChatSearchParams,
} from '@/hooks/use-chat-request';
import { useFetchUserInfo } from '@/hooks/user-setting-hooks';
import { buildMessageUuidWithRole } from '@/utils/chat';
import { Trash2 } from 'lucide-react';
import { useCallback } from 'react';
import {
useGetSendButtonDisabled,
useSendButtonDisabled,
} from '../../hooks/use-button-disabled';
import { useCreateConversationBeforeUploadDocument } from '../../hooks/use-create-conversation';
import { useSendMessage } from '../../hooks/use-send-chat-message';
import { buildMessageItemReference } from '../../utils';
import { useAddChatBox } from '../use-add-box';
type MultipleChatBoxProps = {
controller: AbortController;
chatBoxIds: string[];
} & Pick<ReturnType<typeof useAddChatBox>, 'removeChatBox'>;
type ChatCardProps = { id: string } & Pick<
MultipleChatBoxProps,
'controller' | 'removeChatBox'
>;
function ChatCard({ controller, removeChatBox, id }: ChatCardProps) {
const {
value,
// scrollRef,
messageContainerRef,
sendLoading,
derivedMessages,
handleInputChange,
handlePressEnter,
regenerateMessage,
removeMessageById,
stopOutputMessage,
} = useSendMessage(controller);
const { data: userInfo } = useFetchUserInfo();
const { data: currentDialog } = useFetchDialog();
const { data: conversation } = useFetchConversation();
const handleRemoveChatBox = useCallback(() => {
removeChatBox(id);
}, [id, removeChatBox]);
return (
<Card className="bg-transparent border flex-1">
<CardHeader className="border-b px-5 py-3">
<CardTitle className="flex justify-between items-center">
<div>
<span className="text-base">Card Title</span>
<Button variant={'ghost'} className="ml-2">
GPT-4
</Button>
</div>
<Button variant={'ghost'} onClick={handleRemoveChatBox}>
<Trash2 />
</Button>
</CardTitle>
</CardHeader>
<CardContent>
<div ref={messageContainerRef} className="flex-1 overflow-auto min-h-0">
<div className="w-full">
{derivedMessages?.map((message, i) => {
return (
<MessageItem
loading={
message.role === MessageType.Assistant &&
sendLoading &&
derivedMessages.length - 1 === i
}
key={buildMessageUuidWithRole(message)}
item={message}
nickname={userInfo.nickname}
avatar={userInfo.avatar}
avatarDialog={currentDialog.icon}
reference={buildMessageItemReference(
{
message: derivedMessages,
reference: conversation.reference,
},
message,
)}
// clickDocumentButton={clickDocumentButton}
index={i}
removeMessageById={removeMessageById}
regenerateMessage={regenerateMessage}
sendLoading={sendLoading}
></MessageItem>
);
})}
</div>
{/* <div ref={scrollRef} /> */}
</div>
</CardContent>
</Card>
);
}
export function MultipleChatBox({
controller,
chatBoxIds,
removeChatBox,
}: MultipleChatBoxProps) {
const {
value,
sendLoading,
handleInputChange,
handlePressEnter,
stopOutputMessage,
} = useSendMessage(controller);
const { createConversationBeforeUploadDocument } =
useCreateConversationBeforeUploadDocument();
const { conversationId } = useGetChatSearchParams();
const disabled = useGetSendButtonDisabled();
const sendDisabled = useSendButtonDisabled(value);
return (
<section className="h-full flex flex-col">
<div className="flex gap-4 flex-1 px-5 pb-12">
{chatBoxIds.map((id) => (
<ChatCard
key={id}
controller={controller}
id={id}
removeChatBox={removeChatBox}
></ChatCard>
))}
</div>
<NextMessageInput
disabled={disabled}
sendDisabled={sendDisabled}
sendLoading={sendLoading}
value={value}
onInputChange={handleInputChange}
onPressEnter={handlePressEnter}
conversationId={conversationId}
createConversationBeforeUploadDocument={
createConversationBeforeUploadDocument
}
stopOutputMessage={stopOutputMessage}
/>
</section>
);
}

View File

@ -11,16 +11,16 @@ import { buildMessageUuidWithRole } from '@/utils/chat';
import {
useGetSendButtonDisabled,
useSendButtonDisabled,
} from '../hooks/use-button-disabled';
import { useCreateConversationBeforeUploadDocument } from '../hooks/use-create-conversation';
import { useSendMessage } from '../hooks/use-send-chat-message';
import { buildMessageItemReference } from '../utils';
} from '../../hooks/use-button-disabled';
import { useCreateConversationBeforeUploadDocument } from '../../hooks/use-create-conversation';
import { useSendMessage } from '../../hooks/use-send-chat-message';
import { buildMessageItemReference } from '../../utils';
interface IProps {
controller: AbortController;
}
export function ChatBox({ controller }: IProps) {
export function SingleChatBox({ controller }: IProps) {
const {
value,
// scrollRef,
@ -43,7 +43,7 @@ export function ChatBox({ controller }: IProps) {
const sendDisabled = useSendButtonDisabled(value);
return (
<section className="border-x flex flex-col p-5 flex-1 min-w-0">
<section className="flex flex-col p-5 h-full">
<div ref={messageContainerRef} className="flex-1 overflow-auto min-h-0">
<div className="w-full">
{derivedMessages?.map((message, i) => {

View File

@ -7,14 +7,20 @@ import {
BreadcrumbPage,
BreadcrumbSeparator,
} from '@/components/ui/breadcrumb';
import { Button } from '@/components/ui/button';
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
import { useSetModalState } from '@/hooks/common-hooks';
import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks';
import { useFetchDialog } from '@/hooks/use-chat-request';
import { cn } from '@/lib/utils';
import { Plus } from 'lucide-react';
import { useTranslation } from 'react-i18next';
import { useHandleClickConversationCard } from '../hooks/use-click-card';
import { ChatSettings } from './app-settings/chat-settings';
import { ChatBox } from './chat-box';
import { MultipleChatBox } from './chat-box/multiple-chat-box';
import { SingleChatBox } from './chat-box/single-chat-box';
import { Sessions } from './sessions';
import { useAddChatBox } from './use-add-box';
export default function Chat() {
const { navigateToChatList } = useNavigatePage();
@ -24,9 +30,16 @@ export default function Chat() {
useHandleClickConversationCard();
const { visible: settingVisible, switchVisible: switchSettingVisible } =
useSetModalState(true);
const {
removeChatBox,
addChatBox,
chatBoxIds,
hasSingleChatBox,
hasThreeChatBox,
} = useAddChatBox();
return (
<section className="h-full flex flex-col">
<section className="h-full flex flex-col pr-5">
<PageHeader>
<Breadcrumb>
<BreadcrumbList>
@ -43,18 +56,52 @@ export default function Chat() {
</Breadcrumb>
</PageHeader>
<div className="flex flex-1 min-h-0">
<div className="flex flex-1 min-w-0">
<Sessions
handleConversationCardClick={handleConversationCardClick}
switchSettingVisible={switchSettingVisible}
></Sessions>
<ChatBox controller={controller}></ChatBox>
</div>
{settingVisible && (
<ChatSettings
switchSettingVisible={switchSettingVisible}
></ChatSettings>
)}
<Sessions
handleConversationCardClick={handleConversationCardClick}
switchSettingVisible={switchSettingVisible}
></Sessions>
<Card className="flex-1 min-w-0 bg-transparent border h-full">
<CardContent className="flex p-0 h-full">
<Card className="flex flex-col flex-1 bg-transparent">
<CardHeader
className={cn('p-5', { 'border-b': hasSingleChatBox })}
>
<CardTitle className="flex justify-between items-center">
<div className="text-base">
Card Title
<Button variant={'ghost'} className="ml-2">
GPT-4
</Button>
</div>
<Button
variant={'ghost'}
onClick={addChatBox}
disabled={hasThreeChatBox}
>
<Plus></Plus> Multiple Models
</Button>
</CardTitle>
</CardHeader>
<CardContent className="flex-1 p-0">
{hasSingleChatBox ? (
<SingleChatBox controller={controller}></SingleChatBox>
) : (
<MultipleChatBox
chatBoxIds={chatBoxIds}
controller={controller}
removeChatBox={removeChatBox}
></MultipleChatBox>
)}
</CardContent>
</Card>
{settingVisible && (
<ChatSettings
switchSettingVisible={switchSettingVisible}
></ChatSettings>
)}
</CardContent>
</Card>
</div>
</section>
);

View File

@ -0,0 +1,26 @@
import { useCallback, useState } from 'react';
import { v4 as uuid } from 'uuid';
export function useAddChatBox() {
const [ids, setIds] = useState<string[]>([uuid()]);
const hasSingleChatBox = ids.length === 1;
const hasThreeChatBox = ids.length === 3;
const addChatBox = useCallback(() => {
setIds((prev) => [...prev, uuid()]);
}, []);
const removeChatBox = useCallback((id: string) => {
setIds((prev) => prev.filter((x) => x !== id));
}, []);
return {
chatBoxIds: ids,
hasSingleChatBox,
hasThreeChatBox,
addChatBox,
removeChatBox,
};
}

View File

@ -1,4 +1,4 @@
import PasswordInput from '@/components/password-input';
import PasswordInput from '@/components/originui/password-input';
import { Avatar, AvatarFallback, AvatarImage } from '@/components/ui/avatar';
import { Button } from '@/components/ui/button';
import {