mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-05 01:55:05 +08:00
Compare commits
7 Commits
da5cef0686
...
5e8cd693a5
| Author | SHA1 | Date | |
|---|---|---|---|
| 5e8cd693a5 | |||
| 29f297b850 | |||
| 7235638607 | |||
| 00919fd599 | |||
| 43c0792ffd | |||
| 4b1b68c5fc | |||
| 3492f54c7a |
46
.github/ISSUE_TEMPLATE/agent_scenario_request.yml
vendored
Normal file
46
.github/ISSUE_TEMPLATE/agent_scenario_request.yml
vendored
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
name: "❤️🔥ᴬᴳᴱᴺᵀ Agent scenario request"
|
||||||
|
description: Propose a agent scenario request for RAGFlow.
|
||||||
|
title: "[Agent Scenario Request]: "
|
||||||
|
labels: ["❤️🔥ᴬᴳᴱᴺᵀ agent scenario"]
|
||||||
|
body:
|
||||||
|
- type: checkboxes
|
||||||
|
attributes:
|
||||||
|
label: Self Checks
|
||||||
|
description: "Please check the following in order to be responded in time :)"
|
||||||
|
options:
|
||||||
|
- label: I have searched for existing issues [search for existing issues](https://github.com/infiniflow/ragflow/issues), including closed ones.
|
||||||
|
required: true
|
||||||
|
- label: I confirm that I am using English to submit this report ([Language Policy](https://github.com/infiniflow/ragflow/issues/5910)).
|
||||||
|
required: true
|
||||||
|
- label: Non-english title submitions will be closed directly ( 非英文标题的提交将会被直接关闭 ) ([Language Policy](https://github.com/infiniflow/ragflow/issues/5910)).
|
||||||
|
required: true
|
||||||
|
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||||
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Is your feature request related to a scenario?
|
||||||
|
description: |
|
||||||
|
A clear and concise description of what the scenario is. Ex. I'm always frustrated when [...]
|
||||||
|
render: Markdown
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Describe the feature you'd like
|
||||||
|
description: A clear and concise description of what you want to happen.
|
||||||
|
validations:
|
||||||
|
required: true
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Documentation, adoption, use case
|
||||||
|
description: If you can, explain some scenarios how users might use this, situations it would be helpful in. Any API designs, mockups, or diagrams are also helpful.
|
||||||
|
render: Markdown
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
|
- type: textarea
|
||||||
|
attributes:
|
||||||
|
label: Additional information
|
||||||
|
description: |
|
||||||
|
Add any other context or screenshots about the feature request here.
|
||||||
|
validations:
|
||||||
|
required: false
|
||||||
@ -24,7 +24,8 @@ from typing import Any
|
|||||||
import json_repair
|
import json_repair
|
||||||
|
|
||||||
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
|
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.mcp_server_service import MCPServerService
|
from api.db.services.mcp_server_service import MCPServerService
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from rag.prompts import message_fit_in
|
from rag.prompts import message_fit_in
|
||||||
|
|||||||
@ -24,7 +24,8 @@ from copy import deepcopy
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from agent.component.base import ComponentBase, ComponentParamBase
|
from agent.component.base import ComponentBase, ComponentParamBase
|
||||||
from api.utils.api_utils import timeout
|
from api.utils.api_utils import timeout
|
||||||
from rag.prompts import message_fit_in, citation_prompt
|
from rag.prompts import message_fit_in, citation_prompt
|
||||||
|
|||||||
@ -28,8 +28,8 @@ from api.db.db_models import APIToken
|
|||||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||||
from api.db.services.dialog_service import DialogService, ask, chat
|
from api.db.services.dialog_service import DialogService, ask, chat
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantService
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService, TenantService
|
||||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from flask import request
|
|||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
from api.db.services.dialog_service import DialogService
|
from api.db.services.dialog_service import DialogService
|
||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api import settings
|
from api import settings
|
||||||
|
|||||||
@ -17,7 +17,8 @@ import logging
|
|||||||
import json
|
import json
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||||
|
from api.db.services.llm_service import LLMService
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||||
from api.db import StatusEnum, LLMType
|
from api.db import StatusEnum, LLMType
|
||||||
|
|||||||
@ -21,7 +21,7 @@ from api import settings
|
|||||||
from api.db import StatusEnum
|
from api.db import StatusEnum
|
||||||
from api.db.services.dialog_service import DialogService
|
from api.db.services.dialog_service import DialogService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.user_service import TenantService
|
from api.db.services.user_service import TenantService
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
||||||
|
|||||||
@ -32,7 +32,8 @@ from api.db.services.document_service import DocumentService
|
|||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.task_service import TaskService, queue_tasks
|
from api.db.services.task_service import TaskService, queue_tasks
|
||||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
|
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
|
||||||
from rag.app.qa import beAdoc, rmPrefix
|
from rag.app.qa import beAdoc, rmPrefix
|
||||||
|
|||||||
@ -16,10 +16,8 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from flask import Response, jsonify, request
|
from flask import Response, jsonify, request
|
||||||
|
|
||||||
from agent.canvas import Canvas
|
from agent.canvas import Canvas
|
||||||
from api.db import LLMType, StatusEnum
|
from api.db import LLMType, StatusEnum
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
@ -29,7 +27,6 @@ from api.db.services.canvas_service import completion as agent_completion
|
|||||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
from api.db.services.conversation_service import ConversationService, iframe_completion
|
||||||
from api.db.services.conversation_service import completion as rag_completion
|
from api.db.services.conversation_service import completion as rag_completion
|
||||||
from api.db.services.dialog_service import DialogService, ask, chat
|
from api.db.services.dialog_service import DialogService, ask, chat
|
||||||
from api.db.services.file_service import FileService
|
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.utils import get_uuid
|
from api.utils import get_uuid
|
||||||
@ -69,11 +66,7 @@ def create(tenant_id, chat_id):
|
|||||||
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
||||||
@token_required
|
@token_required
|
||||||
def create_agent_session(tenant_id, agent_id):
|
def create_agent_session(tenant_id, agent_id):
|
||||||
req = request.json
|
user_id = request.args.get("user_id", tenant_id)
|
||||||
if not request.is_json:
|
|
||||||
req = request.form
|
|
||||||
files = request.files
|
|
||||||
user_id = request.args.get("user_id", "")
|
|
||||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||||
if not e:
|
if not e:
|
||||||
return get_error_data_result("Agent not found.")
|
return get_error_data_result("Agent not found.")
|
||||||
@ -82,46 +75,21 @@ def create_agent_session(tenant_id, agent_id):
|
|||||||
if not isinstance(cvs.dsl, str):
|
if not isinstance(cvs.dsl, str):
|
||||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
|
|
||||||
canvas = Canvas(cvs.dsl, tenant_id)
|
session_id=get_uuid()
|
||||||
|
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
||||||
canvas.reset()
|
canvas.reset()
|
||||||
query = canvas.get_preset_param()
|
conv = {
|
||||||
if query:
|
"id": session_id,
|
||||||
for ele in query:
|
"dialog_id": cvs.id,
|
||||||
if not ele["optional"]:
|
"user_id": user_id,
|
||||||
if ele["type"] == "file":
|
"message": [],
|
||||||
if files is None or not files.get(ele["key"]):
|
"source": "agent",
|
||||||
return get_error_data_result(f"`{ele['key']}` with type `{ele['type']}` is required")
|
"dsl": cvs.dsl
|
||||||
upload_file = files.get(ele["key"])
|
}
|
||||||
file_content = FileService.parse_docs([upload_file], user_id)
|
API4ConversationService.save(**conv)
|
||||||
file_name = upload_file.filename
|
|
||||||
ele["value"] = file_name + "\n" + file_content
|
|
||||||
else:
|
|
||||||
if req is None or not req.get(ele["key"]):
|
|
||||||
return get_error_data_result(f"`{ele['key']}` with type `{ele['type']}` is required")
|
|
||||||
ele["value"] = req[ele["key"]]
|
|
||||||
else:
|
|
||||||
if ele["type"] == "file":
|
|
||||||
if files is not None and files.get(ele["key"]):
|
|
||||||
upload_file = files.get(ele["key"])
|
|
||||||
file_content = FileService.parse_docs([upload_file], user_id)
|
|
||||||
file_name = upload_file.filename
|
|
||||||
ele["value"] = file_name + "\n" + file_content
|
|
||||||
else:
|
|
||||||
if "value" in ele:
|
|
||||||
ele.pop("value")
|
|
||||||
else:
|
|
||||||
if req is not None and req.get(ele["key"]):
|
|
||||||
ele["value"] = req[ele["key"]]
|
|
||||||
else:
|
|
||||||
if "value" in ele:
|
|
||||||
ele.pop("value")
|
|
||||||
|
|
||||||
for ans in canvas.run(stream=False):
|
|
||||||
pass
|
|
||||||
|
|
||||||
cvs.dsl = json.loads(str(canvas))
|
cvs.dsl = json.loads(str(canvas))
|
||||||
conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
||||||
API4ConversationService.save(**conv)
|
|
||||||
conv["agent_id"] = conv.pop("dialog_id")
|
conv["agent_id"] = conv.pop("dialog_id")
|
||||||
return get_result(data=conv)
|
return get_result(data=conv)
|
||||||
|
|
||||||
|
|||||||
@ -28,7 +28,7 @@ from api.apps.auth import get_auth_client
|
|||||||
from api.db import FileType, UserTenantRole
|
from api.db import FileType, UserTenantRole
|
||||||
from api.db.db_models import TenantLLM
|
from api.db.db_models import TenantLLM
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
from api.db.services.llm_service import TenantLLMService, get_init_tenant_llm
|
||||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||||
from api.utils import (
|
from api.utils import (
|
||||||
current_timestamp,
|
current_timestamp,
|
||||||
@ -619,57 +619,8 @@ def user_register(user_id, user):
|
|||||||
"size": 0,
|
"size": 0,
|
||||||
"location": "",
|
"location": "",
|
||||||
}
|
}
|
||||||
tenant_llm = []
|
|
||||||
|
|
||||||
seen = set()
|
tenant_llm = get_init_tenant_llm(user_id)
|
||||||
factory_configs = []
|
|
||||||
for factory_config in [
|
|
||||||
settings.CHAT_CFG,
|
|
||||||
settings.EMBEDDING_CFG,
|
|
||||||
settings.ASR_CFG,
|
|
||||||
settings.IMAGE2TEXT_CFG,
|
|
||||||
settings.RERANK_CFG,
|
|
||||||
]:
|
|
||||||
factory_name = factory_config["factory"]
|
|
||||||
if factory_name not in seen:
|
|
||||||
seen.add(factory_name)
|
|
||||||
factory_configs.append(factory_config)
|
|
||||||
|
|
||||||
for factory_config in factory_configs:
|
|
||||||
for llm in LLMService.query(fid=factory_config["factory"]):
|
|
||||||
tenant_llm.append(
|
|
||||||
{
|
|
||||||
"tenant_id": user_id,
|
|
||||||
"llm_factory": factory_config["factory"],
|
|
||||||
"llm_name": llm.llm_name,
|
|
||||||
"model_type": llm.model_type,
|
|
||||||
"api_key": factory_config["api_key"],
|
|
||||||
"api_base": factory_config["base_url"],
|
|
||||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if settings.LIGHTEN != 1:
|
|
||||||
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
|
||||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
|
||||||
tenant_llm.append(
|
|
||||||
{
|
|
||||||
"tenant_id": user_id,
|
|
||||||
"llm_factory": fid,
|
|
||||||
"llm_name": mdlnm,
|
|
||||||
"model_type": "embedding",
|
|
||||||
"api_key": "",
|
|
||||||
"api_base": "",
|
|
||||||
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
unique = {}
|
|
||||||
for item in tenant_llm:
|
|
||||||
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
|
||||||
if key not in unique:
|
|
||||||
unique[key] = item
|
|
||||||
tenant_llm = list(unique.values())
|
|
||||||
|
|
||||||
if not UserService.save(**user):
|
if not UserService.save(**user):
|
||||||
return
|
return
|
||||||
|
|||||||
@ -27,7 +27,8 @@ from api.db.services import UserService
|
|||||||
from api.db.services.canvas_service import CanvasTemplateService
|
from api.db.services.canvas_service import CanvasTemplateService
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||||
|
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||||
from api.db.services.user_service import TenantService, UserTenantService
|
from api.db.services.user_service import TenantService, UserTenantService
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.utils.file_utils import get_project_base_directory
|
from api.utils.file_utils import get_project_base_directory
|
||||||
@ -64,43 +65,7 @@ def init_superuser():
|
|||||||
"role": UserTenantRole.OWNER
|
"role": UserTenantRole.OWNER
|
||||||
}
|
}
|
||||||
|
|
||||||
user_id = user_info
|
tenant_llm = get_init_tenant_llm(user_info["id"])
|
||||||
tenant_llm = []
|
|
||||||
|
|
||||||
seen = set()
|
|
||||||
factory_configs = []
|
|
||||||
for factory_config in [
|
|
||||||
settings.CHAT_CFG["factory"],
|
|
||||||
settings.EMBEDDING_CFG["factory"],
|
|
||||||
settings.ASR_CFG["factory"],
|
|
||||||
settings.IMAGE2TEXT_CFG["factory"],
|
|
||||||
settings.RERANK_CFG["factory"],
|
|
||||||
]:
|
|
||||||
factory_name = factory_config["factory"]
|
|
||||||
if factory_name not in seen:
|
|
||||||
seen.add(factory_name)
|
|
||||||
factory_configs.append(factory_config)
|
|
||||||
|
|
||||||
for factory_config in factory_configs:
|
|
||||||
for llm in LLMService.query(fid=factory_config["factory"]):
|
|
||||||
tenant_llm.append(
|
|
||||||
{
|
|
||||||
"tenant_id": user_id,
|
|
||||||
"llm_factory": factory_config["factory"],
|
|
||||||
"llm_name": llm.llm_name,
|
|
||||||
"model_type": llm.model_type,
|
|
||||||
"api_key": factory_config["api_key"],
|
|
||||||
"api_base": factory_config["base_url"],
|
|
||||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
unique = {}
|
|
||||||
for item in tenant_llm:
|
|
||||||
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
|
||||||
if key not in unique:
|
|
||||||
unique[key] = item
|
|
||||||
tenant_llm = list(unique.values())
|
|
||||||
|
|
||||||
if not UserService.save(**user_info):
|
if not UserService.save(**user_info):
|
||||||
logging.error("can't init admin.")
|
logging.error("can't init admin.")
|
||||||
|
|||||||
@ -33,7 +33,8 @@ from api.db.services.common_service import CommonService
|
|||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.langfuse_service import TenantLangfuseService
|
from api.db.services.langfuse_service import TenantLangfuseService
|
||||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.utils import current_timestamp, datetime_format
|
from api.utils import current_timestamp, datetime_format
|
||||||
from rag.app.resume import forbidden_select_fields4resume
|
from rag.app.resume import forbidden_select_fields4resume
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
@ -365,8 +366,12 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
if dialog.meta_data_filter.get("method") == "auto":
|
if dialog.meta_data_filter.get("method") == "auto":
|
||||||
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
|
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||||
attachments.extend(meta_filter(metas, filters))
|
attachments.extend(meta_filter(metas, filters))
|
||||||
|
if not attachments:
|
||||||
|
attachments = None
|
||||||
elif dialog.meta_data_filter.get("method") == "manual":
|
elif dialog.meta_data_filter.get("method") == "manual":
|
||||||
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
|
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
|
||||||
|
if not attachments:
|
||||||
|
attachments = None
|
||||||
|
|
||||||
if prompt_config.get("keyword", False):
|
if prompt_config.get("keyword", False):
|
||||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||||
@ -375,17 +380,16 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||||||
|
|
||||||
thought = ""
|
thought = ""
|
||||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||||
|
knowledges = []
|
||||||
|
|
||||||
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
|
||||||
knowledges = []
|
|
||||||
else:
|
|
||||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||||
knowledges = []
|
knowledges = []
|
||||||
if prompt_config.get("reasoning", False):
|
if prompt_config.get("reasoning", False):
|
||||||
reasoner = DeepResearcher(
|
reasoner = DeepResearcher(
|
||||||
chat_mdl,
|
chat_mdl,
|
||||||
prompt_config,
|
prompt_config,
|
||||||
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3),
|
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3, doc_ids=attachments),
|
||||||
)
|
)
|
||||||
|
|
||||||
for think in reasoner.thinking(kbinfos, " ".join(questions)):
|
for think in reasoner.thinking(kbinfos, " ".join(questions)):
|
||||||
|
|||||||
@ -18,246 +18,73 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Generator
|
from typing import Generator
|
||||||
|
from api.db.db_models import LLM
|
||||||
from langfuse import Langfuse
|
|
||||||
|
|
||||||
from api import settings
|
|
||||||
from api.db import LLMType
|
|
||||||
from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
|
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
from api.db.services.langfuse_service import TenantLangfuseService
|
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||||
from api.db.services.user_service import TenantService
|
|
||||||
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
|
|
||||||
|
|
||||||
|
|
||||||
class LLMFactoriesService(CommonService):
|
|
||||||
model = LLMFactories
|
|
||||||
|
|
||||||
|
|
||||||
class LLMService(CommonService):
|
class LLMService(CommonService):
|
||||||
model = LLM
|
model = LLM
|
||||||
|
|
||||||
|
|
||||||
class TenantLLMService(CommonService):
|
def get_init_tenant_llm(user_id):
|
||||||
model = TenantLLM
|
from api import settings
|
||||||
|
tenant_llm = []
|
||||||
|
|
||||||
@classmethod
|
seen = set()
|
||||||
@DB.connection_context()
|
factory_configs = []
|
||||||
def get_api_key(cls, tenant_id, model_name):
|
for factory_config in [
|
||||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
settings.CHAT_CFG,
|
||||||
if not fid:
|
settings.EMBEDDING_CFG,
|
||||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
settings.ASR_CFG,
|
||||||
else:
|
settings.IMAGE2TEXT_CFG,
|
||||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
settings.RERANK_CFG,
|
||||||
|
]:
|
||||||
|
factory_name = factory_config["factory"]
|
||||||
|
if factory_name not in seen:
|
||||||
|
seen.add(factory_name)
|
||||||
|
factory_configs.append(factory_config)
|
||||||
|
|
||||||
if (not objs) and fid:
|
for factory_config in factory_configs:
|
||||||
if fid == "LocalAI":
|
for llm in LLMService.query(fid=factory_config["factory"]):
|
||||||
mdlnm += "___LocalAI"
|
tenant_llm.append(
|
||||||
elif fid == "HuggingFace":
|
{
|
||||||
mdlnm += "___HuggingFace"
|
"tenant_id": user_id,
|
||||||
elif fid == "OpenAI-API-Compatible":
|
"llm_factory": factory_config["factory"],
|
||||||
mdlnm += "___OpenAI-API"
|
"llm_name": llm.llm_name,
|
||||||
elif fid == "VLLM":
|
"model_type": llm.model_type,
|
||||||
mdlnm += "___VLLM"
|
"api_key": factory_config["api_key"],
|
||||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
"api_base": factory_config["base_url"],
|
||||||
if not objs:
|
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||||
return
|
}
|
||||||
return objs[0]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@DB.connection_context()
|
|
||||||
def get_my_llms(cls, tenant_id):
|
|
||||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
|
|
||||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
|
||||||
|
|
||||||
return list(objs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def split_model_name_and_factory(model_name):
|
|
||||||
arr = model_name.split("@")
|
|
||||||
if len(arr) < 2:
|
|
||||||
return model_name, None
|
|
||||||
if len(arr) > 2:
|
|
||||||
return "@".join(arr[0:-1]), arr[-1]
|
|
||||||
|
|
||||||
# model name must be xxx@yyy
|
|
||||||
try:
|
|
||||||
model_factories = settings.FACTORY_LLM_INFOS
|
|
||||||
model_providers = set([f["name"] for f in model_factories])
|
|
||||||
if arr[-1] not in model_providers:
|
|
||||||
return model_name, None
|
|
||||||
return arr[0], arr[-1]
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
|
|
||||||
return model_name, None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@DB.connection_context()
|
|
||||||
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
|
|
||||||
e, tenant = TenantService.get_by_id(tenant_id)
|
|
||||||
if not e:
|
|
||||||
raise LookupError("Tenant not found")
|
|
||||||
|
|
||||||
if llm_type == LLMType.EMBEDDING.value:
|
|
||||||
mdlnm = tenant.embd_id if not llm_name else llm_name
|
|
||||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
|
||||||
mdlnm = tenant.asr_id
|
|
||||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
|
||||||
mdlnm = tenant.img2txt_id if not llm_name else llm_name
|
|
||||||
elif llm_type == LLMType.CHAT.value:
|
|
||||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
|
||||||
elif llm_type == LLMType.RERANK:
|
|
||||||
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
|
||||||
elif llm_type == LLMType.TTS:
|
|
||||||
mdlnm = tenant.tts_id if not llm_name else llm_name
|
|
||||||
else:
|
|
||||||
assert False, "LLM type error"
|
|
||||||
|
|
||||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
|
||||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
|
||||||
if not model_config: # for some cases seems fid mismatch
|
|
||||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
|
||||||
if model_config:
|
|
||||||
model_config = model_config.to_dict()
|
|
||||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
|
||||||
if not llm and fid: # for some cases seems fid mismatch
|
|
||||||
llm = LLMService.query(llm_name=mdlnm)
|
|
||||||
if llm:
|
|
||||||
model_config["is_tools"] = llm[0].is_tools
|
|
||||||
if not model_config:
|
|
||||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
|
||||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
|
||||||
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
|
||||||
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
|
|
||||||
if not model_config:
|
|
||||||
if mdlnm == "flag-embedding":
|
|
||||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
|
|
||||||
else:
|
|
||||||
if not mdlnm:
|
|
||||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
|
||||||
raise LookupError("Model({}) not authorized".format(mdlnm))
|
|
||||||
return model_config
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@DB.connection_context()
|
|
||||||
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
|
||||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
|
||||||
kwargs.update({"provider": model_config["llm_factory"]})
|
|
||||||
if llm_type == LLMType.EMBEDDING.value:
|
|
||||||
if model_config["llm_factory"] not in EmbeddingModel:
|
|
||||||
return
|
|
||||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
|
||||||
|
|
||||||
if llm_type == LLMType.RERANK:
|
|
||||||
if model_config["llm_factory"] not in RerankModel:
|
|
||||||
return
|
|
||||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
|
||||||
|
|
||||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
|
||||||
if model_config["llm_factory"] not in CvModel:
|
|
||||||
return
|
|
||||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
|
||||||
|
|
||||||
if llm_type == LLMType.CHAT.value:
|
|
||||||
if model_config["llm_factory"] not in ChatModel:
|
|
||||||
return
|
|
||||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
|
||||||
|
|
||||||
if llm_type == LLMType.SPEECH2TEXT:
|
|
||||||
if model_config["llm_factory"] not in Seq2txtModel:
|
|
||||||
return
|
|
||||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
|
||||||
if llm_type == LLMType.TTS:
|
|
||||||
if model_config["llm_factory"] not in TTSModel:
|
|
||||||
return
|
|
||||||
return TTSModel[model_config["llm_factory"]](
|
|
||||||
model_config["api_key"],
|
|
||||||
model_config["llm_name"],
|
|
||||||
base_url=model_config["api_base"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
if settings.LIGHTEN != 1:
|
||||||
@DB.connection_context()
|
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
||||||
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
||||||
e, tenant = TenantService.get_by_id(tenant_id)
|
tenant_llm.append(
|
||||||
if not e:
|
{
|
||||||
logging.error(f"Tenant not found: {tenant_id}")
|
"tenant_id": user_id,
|
||||||
return 0
|
"llm_factory": fid,
|
||||||
|
"llm_name": mdlnm,
|
||||||
llm_map = {
|
"model_type": "embedding",
|
||||||
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
|
"api_key": "",
|
||||||
LLMType.SPEECH2TEXT.value: tenant.asr_id,
|
"api_base": "",
|
||||||
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
|
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
|
||||||
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
|
}
|
||||||
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
|
|
||||||
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
|
|
||||||
}
|
|
||||||
|
|
||||||
mdlnm = llm_map.get(llm_type)
|
|
||||||
if mdlnm is None:
|
|
||||||
logging.error(f"LLM type error: {llm_type}")
|
|
||||||
return 0
|
|
||||||
|
|
||||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
|
||||||
|
|
||||||
try:
|
|
||||||
num = (
|
|
||||||
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
|
||||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
|
||||||
.execute()
|
|
||||||
)
|
)
|
||||||
except Exception:
|
|
||||||
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
return num
|
unique = {}
|
||||||
|
for item in tenant_llm:
|
||||||
@classmethod
|
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
||||||
@DB.connection_context()
|
if key not in unique:
|
||||||
def get_openai_models(cls):
|
unique[key] = item
|
||||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
return list(unique.values())
|
||||||
return list(objs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
|
||||||
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
|
|
||||||
llm_factories = settings.FACTORY_LLM_INFOS
|
|
||||||
for llm_factory in llm_factories:
|
|
||||||
for llm in llm_factory["llm"]:
|
|
||||||
if llm_id == llm["llm_name"]:
|
|
||||||
return llm["model_type"].split(",")[-1]
|
|
||||||
|
|
||||||
for llm in LLMService.query(llm_name=llm_id):
|
|
||||||
return llm.model_type
|
|
||||||
|
|
||||||
llm = TenantLLMService.get_or_none(llm_name=llm_id)
|
|
||||||
if llm:
|
|
||||||
return llm.model_type
|
|
||||||
for llm in TenantLLMService.query(llm_name=llm_id):
|
|
||||||
return llm.model_type
|
|
||||||
|
|
||||||
|
|
||||||
class LLMBundle:
|
class LLMBundle(LLM4Tenant):
|
||||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||||
self.tenant_id = tenant_id
|
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
|
||||||
self.llm_type = llm_type
|
|
||||||
self.llm_name = llm_name
|
|
||||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
|
|
||||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
|
||||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
|
||||||
self.max_length = model_config.get("max_tokens", 8192)
|
|
||||||
|
|
||||||
self.is_tools = model_config.get("is_tools", False)
|
|
||||||
self.verbose_tool_use = kwargs.get("verbose_tool_use")
|
|
||||||
|
|
||||||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
|
|
||||||
self.langfuse = None
|
|
||||||
if langfuse_keys:
|
|
||||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
|
||||||
if langfuse.auth_check():
|
|
||||||
self.langfuse = langfuse
|
|
||||||
trace_id = self.langfuse.create_trace_id()
|
|
||||||
self.trace_context = {"trace_id": trace_id}
|
|
||||||
|
|
||||||
def bind_tools(self, toolcall_session, tools):
|
def bind_tools(self, toolcall_session, tools):
|
||||||
if not self.is_tools:
|
if not self.is_tools:
|
||||||
|
|||||||
252
api/db/services/tenant_llm_service.py
Normal file
252
api/db/services/tenant_llm_service.py
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
import logging
|
||||||
|
from langfuse import Langfuse
|
||||||
|
from api import settings
|
||||||
|
from api.db import LLMType
|
||||||
|
from api.db.db_models import DB, LLMFactories, TenantLLM
|
||||||
|
from api.db.services.common_service import CommonService
|
||||||
|
from api.db.services.langfuse_service import TenantLangfuseService
|
||||||
|
from api.db.services.user_service import TenantService
|
||||||
|
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
|
||||||
|
|
||||||
|
|
||||||
|
class LLMFactoriesService(CommonService):
|
||||||
|
model = LLMFactories
|
||||||
|
|
||||||
|
|
||||||
|
class TenantLLMService(CommonService):
|
||||||
|
model = TenantLLM
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_api_key(cls, tenant_id, model_name):
|
||||||
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
||||||
|
if not fid:
|
||||||
|
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
||||||
|
else:
|
||||||
|
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||||
|
|
||||||
|
if (not objs) and fid:
|
||||||
|
if fid == "LocalAI":
|
||||||
|
mdlnm += "___LocalAI"
|
||||||
|
elif fid == "HuggingFace":
|
||||||
|
mdlnm += "___HuggingFace"
|
||||||
|
elif fid == "OpenAI-API-Compatible":
|
||||||
|
mdlnm += "___OpenAI-API"
|
||||||
|
elif fid == "VLLM":
|
||||||
|
mdlnm += "___VLLM"
|
||||||
|
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||||
|
if not objs:
|
||||||
|
return
|
||||||
|
return objs[0]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_my_llms(cls, tenant_id):
|
||||||
|
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
|
||||||
|
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||||
|
|
||||||
|
return list(objs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def split_model_name_and_factory(model_name):
|
||||||
|
arr = model_name.split("@")
|
||||||
|
if len(arr) < 2:
|
||||||
|
return model_name, None
|
||||||
|
if len(arr) > 2:
|
||||||
|
return "@".join(arr[0:-1]), arr[-1]
|
||||||
|
|
||||||
|
# model name must be xxx@yyy
|
||||||
|
try:
|
||||||
|
model_factories = settings.FACTORY_LLM_INFOS
|
||||||
|
model_providers = set([f["name"] for f in model_factories])
|
||||||
|
if arr[-1] not in model_providers:
|
||||||
|
return model_name, None
|
||||||
|
return arr[0], arr[-1]
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
|
||||||
|
return model_name, None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
|
||||||
|
from api.db.services.llm_service import LLMService
|
||||||
|
e, tenant = TenantService.get_by_id(tenant_id)
|
||||||
|
if not e:
|
||||||
|
raise LookupError("Tenant not found")
|
||||||
|
|
||||||
|
if llm_type == LLMType.EMBEDDING.value:
|
||||||
|
mdlnm = tenant.embd_id if not llm_name else llm_name
|
||||||
|
elif llm_type == LLMType.SPEECH2TEXT.value:
|
||||||
|
mdlnm = tenant.asr_id
|
||||||
|
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||||
|
mdlnm = tenant.img2txt_id if not llm_name else llm_name
|
||||||
|
elif llm_type == LLMType.CHAT.value:
|
||||||
|
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||||
|
elif llm_type == LLMType.RERANK:
|
||||||
|
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
||||||
|
elif llm_type == LLMType.TTS:
|
||||||
|
mdlnm = tenant.tts_id if not llm_name else llm_name
|
||||||
|
else:
|
||||||
|
assert False, "LLM type error"
|
||||||
|
|
||||||
|
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||||
|
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||||
|
if not model_config: # for some cases seems fid mismatch
|
||||||
|
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||||
|
if model_config:
|
||||||
|
model_config = model_config.to_dict()
|
||||||
|
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||||
|
if not llm and fid: # for some cases seems fid mismatch
|
||||||
|
llm = LLMService.query(llm_name=mdlnm)
|
||||||
|
if llm:
|
||||||
|
model_config["is_tools"] = llm[0].is_tools
|
||||||
|
if not model_config:
|
||||||
|
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||||
|
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||||
|
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
||||||
|
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
|
||||||
|
if not model_config:
|
||||||
|
if mdlnm == "flag-embedding":
|
||||||
|
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
|
||||||
|
else:
|
||||||
|
if not mdlnm:
|
||||||
|
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||||
|
raise LookupError("Model({}) not authorized".format(mdlnm))
|
||||||
|
return model_config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||||
|
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||||
|
kwargs.update({"provider": model_config["llm_factory"]})
|
||||||
|
if llm_type == LLMType.EMBEDDING.value:
|
||||||
|
if model_config["llm_factory"] not in EmbeddingModel:
|
||||||
|
return
|
||||||
|
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||||
|
|
||||||
|
if llm_type == LLMType.RERANK:
|
||||||
|
if model_config["llm_factory"] not in RerankModel:
|
||||||
|
return
|
||||||
|
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||||
|
|
||||||
|
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||||
|
if model_config["llm_factory"] not in CvModel:
|
||||||
|
return
|
||||||
|
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
||||||
|
|
||||||
|
if llm_type == LLMType.CHAT.value:
|
||||||
|
if model_config["llm_factory"] not in ChatModel:
|
||||||
|
return
|
||||||
|
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
||||||
|
|
||||||
|
if llm_type == LLMType.SPEECH2TEXT:
|
||||||
|
if model_config["llm_factory"] not in Seq2txtModel:
|
||||||
|
return
|
||||||
|
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
||||||
|
if llm_type == LLMType.TTS:
|
||||||
|
if model_config["llm_factory"] not in TTSModel:
|
||||||
|
return
|
||||||
|
return TTSModel[model_config["llm_factory"]](
|
||||||
|
model_config["api_key"],
|
||||||
|
model_config["llm_name"],
|
||||||
|
base_url=model_config["api_base"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
|
||||||
|
e, tenant = TenantService.get_by_id(tenant_id)
|
||||||
|
if not e:
|
||||||
|
logging.error(f"Tenant not found: {tenant_id}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
llm_map = {
|
||||||
|
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
|
||||||
|
LLMType.SPEECH2TEXT.value: tenant.asr_id,
|
||||||
|
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
|
||||||
|
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
|
||||||
|
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
|
||||||
|
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
mdlnm = llm_map.get(llm_type)
|
||||||
|
if mdlnm is None:
|
||||||
|
logging.error(f"LLM type error: {llm_type}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||||
|
|
||||||
|
try:
|
||||||
|
num = (
|
||||||
|
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
||||||
|
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
||||||
|
.execute()
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return num
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_openai_models(cls):
|
||||||
|
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||||
|
return list(objs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||||
|
from api.db.services.llm_service import LLMService
|
||||||
|
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
|
||||||
|
llm_factories = settings.FACTORY_LLM_INFOS
|
||||||
|
for llm_factory in llm_factories:
|
||||||
|
for llm in llm_factory["llm"]:
|
||||||
|
if llm_id == llm["llm_name"]:
|
||||||
|
return llm["model_type"].split(",")[-1]
|
||||||
|
|
||||||
|
for llm in LLMService.query(llm_name=llm_id):
|
||||||
|
return llm.model_type
|
||||||
|
|
||||||
|
llm = TenantLLMService.get_or_none(llm_name=llm_id)
|
||||||
|
if llm:
|
||||||
|
return llm.model_type
|
||||||
|
for llm in TenantLLMService.query(llm_name=llm_id):
|
||||||
|
return llm.model_type
|
||||||
|
|
||||||
|
|
||||||
|
class LLM4Tenant:
|
||||||
|
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.llm_type = llm_type
|
||||||
|
self.llm_name = llm_name
|
||||||
|
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
|
||||||
|
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
||||||
|
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||||
|
self.max_length = model_config.get("max_tokens", 8192)
|
||||||
|
|
||||||
|
self.is_tools = model_config.get("is_tools", False)
|
||||||
|
self.verbose_tool_use = kwargs.get("verbose_tool_use")
|
||||||
|
|
||||||
|
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
|
||||||
|
self.langfuse = None
|
||||||
|
if langfuse_keys:
|
||||||
|
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
||||||
|
if langfuse.auth_check():
|
||||||
|
self.langfuse = langfuse
|
||||||
|
trace_id = self.langfuse.create_trace_id()
|
||||||
|
self.trace_context = {"trace_id": trace_id}
|
||||||
@ -48,7 +48,8 @@ from werkzeug.http import HTTP_STATUS_CODES
|
|||||||
from api import settings
|
from api import settings
|
||||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||||
from api.db.db_models import APIToken
|
from api.db.db_models import APIToken
|
||||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
from api.db.services.llm_service import LLMService
|
||||||
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
||||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||||
|
|
||||||
|
|||||||
@ -2632,9 +2632,11 @@ data:{
|
|||||||
"document_name": "1.txt",
|
"document_name": "1.txt",
|
||||||
"dataset_id": "8e83e57a884611ef9d760242ac120006",
|
"dataset_id": "8e83e57a884611ef9d760242ac120006",
|
||||||
"image_id": "",
|
"image_id": "",
|
||||||
|
"url": null,
|
||||||
"similarity": 0.7,
|
"similarity": 0.7,
|
||||||
"vector_similarity": 0.0,
|
"vector_similarity": 0.0,
|
||||||
"term_similarity": 1.0,
|
"term_similarity": 1.0,
|
||||||
|
"doc_type": [],
|
||||||
"positions": [
|
"positions": [
|
||||||
""
|
""
|
||||||
]
|
]
|
||||||
@ -2649,6 +2651,7 @@ data:{
|
|||||||
]
|
]
|
||||||
},
|
},
|
||||||
"prompt": "xxxxxxxxxxx",
|
"prompt": "xxxxxxxxxxx",
|
||||||
|
"created_at": 1755055623.6401553,
|
||||||
"id": "a84c5dd4-97b4-4624-8c3b-974012c8000d",
|
"id": "a84c5dd4-97b4-4624-8c3b-974012c8000d",
|
||||||
"session_id": "82b0ab2a9c1911ef9d870242ac120006"
|
"session_id": "82b0ab2a9c1911ef9d870242ac120006"
|
||||||
}
|
}
|
||||||
@ -2681,7 +2684,7 @@ Creates a session with an agent.
|
|||||||
- Method: POST
|
- Method: POST
|
||||||
- URL: `/api/v1/agents/{agent_id}/sessions?user_id={user_id}`
|
- URL: `/api/v1/agents/{agent_id}/sessions?user_id={user_id}`
|
||||||
- Headers:
|
- Headers:
|
||||||
- `'content-Type: application/json' or 'multipart/form-data'`
|
- `'content-Type: application/json'
|
||||||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||||
- Body:
|
- Body:
|
||||||
- the required parameters:`str`
|
- the required parameters:`str`
|
||||||
@ -2701,29 +2704,6 @@ curl --request POST \
|
|||||||
}'
|
}'
|
||||||
```
|
```
|
||||||
|
|
||||||
If the **Begin** component in your agent takes required parameters:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl --request POST \
|
|
||||||
--url http://{address}/api/v1/agents/{agent_id}/sessions \
|
|
||||||
--header 'Content-Type: application/json' \
|
|
||||||
--header 'Authorization: Bearer <YOUR_API_KEY>' \
|
|
||||||
--data '{
|
|
||||||
"lang":"Japanese",
|
|
||||||
"file":"Who are you"
|
|
||||||
}'
|
|
||||||
```
|
|
||||||
|
|
||||||
If the **Begin** component in your agent takes required file parameters:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
curl --request POST \
|
|
||||||
--url http://{address}/api/v1/agents/{agent_id}/sessions?user_id={user_id} \
|
|
||||||
--header 'Content-Type: multipart/form-data' \
|
|
||||||
--header 'Authorization: Bearer <YOUR_API_KEY>' \
|
|
||||||
--form '<FILE_KEY>=@./test1.png'
|
|
||||||
```
|
|
||||||
|
|
||||||
##### Request parameters
|
##### Request parameters
|
||||||
|
|
||||||
- `agent_id`: (*Path parameter*)
|
- `agent_id`: (*Path parameter*)
|
||||||
@ -2739,101 +2719,190 @@ Success:
|
|||||||
{
|
{
|
||||||
"code": 0,
|
"code": 0,
|
||||||
"data": {
|
"data": {
|
||||||
"agent_id": "b4a39922b76611efaa1a0242ac120006",
|
"agent_id": "dbb4ed366e8611f09690a55a6daec4ef",
|
||||||
"dsl": {
|
"dsl": {
|
||||||
"answer": [],
|
|
||||||
"components": {
|
"components": {
|
||||||
"Answer:GreenReadersDrum": {
|
"Message:EightyJobsAsk": {
|
||||||
"downstream": [],
|
"downstream": [],
|
||||||
"obj": {
|
"obj": {
|
||||||
"component_name": "Answer",
|
"component_name": "Message",
|
||||||
"inputs": [],
|
"params": {
|
||||||
"output": null,
|
"content": [
|
||||||
"params": {}
|
"{begin@var1}{begin@var2}"
|
||||||
|
],
|
||||||
|
"debug_inputs": {},
|
||||||
|
"delay_after_error": 2.0,
|
||||||
|
"description": "",
|
||||||
|
"exception_default_value": null,
|
||||||
|
"exception_goto": null,
|
||||||
|
"exception_method": null,
|
||||||
|
"inputs": {},
|
||||||
|
"max_retries": 0,
|
||||||
|
"message_history_window_size": 22,
|
||||||
|
"outputs": {
|
||||||
|
"content": {
|
||||||
|
"type": "str",
|
||||||
|
"value": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"stream": true
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"upstream": []
|
"upstream": [
|
||||||
|
"begin"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"begin": {
|
"begin": {
|
||||||
"downstream": [],
|
"downstream": [
|
||||||
|
"Message:EightyJobsAsk"
|
||||||
|
],
|
||||||
"obj": {
|
"obj": {
|
||||||
"component_name": "Begin",
|
"component_name": "Begin",
|
||||||
"inputs": [],
|
"params": {
|
||||||
"output": {},
|
"debug_inputs": {},
|
||||||
"params": {}
|
"delay_after_error": 2.0,
|
||||||
|
"description": "",
|
||||||
|
"enablePrologue": true,
|
||||||
|
"enable_tips": true,
|
||||||
|
"exception_default_value": null,
|
||||||
|
"exception_goto": null,
|
||||||
|
"exception_method": null,
|
||||||
|
"inputs": {
|
||||||
|
"var1": {
|
||||||
|
"name": "var1",
|
||||||
|
"optional": false,
|
||||||
|
"options": [],
|
||||||
|
"type": "line",
|
||||||
|
"value": null
|
||||||
|
},
|
||||||
|
"var2": {
|
||||||
|
"name": "var2",
|
||||||
|
"optional": false,
|
||||||
|
"options": [],
|
||||||
|
"type": "line",
|
||||||
|
"value": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"max_retries": 0,
|
||||||
|
"message_history_window_size": 22,
|
||||||
|
"mode": "conversational",
|
||||||
|
"outputs": {},
|
||||||
|
"prologue": "Hi! I'm your assistant, what can I do for you?",
|
||||||
|
"tips": "Please fill up the form"
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"upstream": []
|
"upstream": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"embed_id": "",
|
"globals": {
|
||||||
|
"sys.conversation_turns": 0,
|
||||||
|
"sys.files": [],
|
||||||
|
"sys.query": "",
|
||||||
|
"sys.user_id": ""
|
||||||
|
},
|
||||||
"graph": {
|
"graph": {
|
||||||
"edges": [],
|
"edges": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"isHovered": false
|
||||||
|
},
|
||||||
|
"id": "xy-edge__beginstart-Message:EightyJobsAskend",
|
||||||
|
"markerEnd": "logo",
|
||||||
|
"source": "begin",
|
||||||
|
"sourceHandle": "start",
|
||||||
|
"style": {
|
||||||
|
"stroke": "rgba(151, 154, 171, 1)",
|
||||||
|
"strokeWidth": 1
|
||||||
|
},
|
||||||
|
"target": "Message:EightyJobsAsk",
|
||||||
|
"targetHandle": "end",
|
||||||
|
"type": "buttonEdge",
|
||||||
|
"zIndex": 1001
|
||||||
|
}
|
||||||
|
],
|
||||||
"nodes": [
|
"nodes": [
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
|
"form": {
|
||||||
|
"enablePrologue": true,
|
||||||
|
"inputs": {
|
||||||
|
"var1": {
|
||||||
|
"name": "var1",
|
||||||
|
"optional": false,
|
||||||
|
"options": [],
|
||||||
|
"type": "line"
|
||||||
|
},
|
||||||
|
"var2": {
|
||||||
|
"name": "var2",
|
||||||
|
"optional": false,
|
||||||
|
"options": [],
|
||||||
|
"type": "line"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"mode": "conversational",
|
||||||
|
"prologue": "Hi! I'm your assistant, what can I do for you?"
|
||||||
|
},
|
||||||
"label": "Begin",
|
"label": "Begin",
|
||||||
"name": "begin"
|
"name": "begin"
|
||||||
},
|
},
|
||||||
"dragging": false,
|
"dragging": false,
|
||||||
"height": 44,
|
|
||||||
"id": "begin",
|
"id": "begin",
|
||||||
"position": {
|
"measured": {
|
||||||
"x": 53.25688640427177,
|
"height": 112,
|
||||||
"y": 198.37155679786412
|
"width": 200
|
||||||
},
|
},
|
||||||
"positionAbsolute": {
|
"position": {
|
||||||
"x": 53.25688640427177,
|
"x": 270.64098070942583,
|
||||||
"y": 198.37155679786412
|
"y": -56.320928437811176
|
||||||
},
|
},
|
||||||
"selected": false,
|
"selected": false,
|
||||||
"sourcePosition": "left",
|
"sourcePosition": "left",
|
||||||
"targetPosition": "right",
|
"targetPosition": "right",
|
||||||
"type": "beginNode",
|
"type": "beginNode"
|
||||||
"width": 200
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"data": {
|
"data": {
|
||||||
"form": {},
|
"form": {
|
||||||
"label": "Answer",
|
"content": [
|
||||||
"name": "dialog_0"
|
"{begin@var1}{begin@var2}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"label": "Message",
|
||||||
|
"name": "Message_0"
|
||||||
},
|
},
|
||||||
"dragging": false,
|
"dragging": false,
|
||||||
"height": 44,
|
"id": "Message:EightyJobsAsk",
|
||||||
"id": "Answer:GreenReadersDrum",
|
"measured": {
|
||||||
|
"height": 57,
|
||||||
|
"width": 200
|
||||||
|
},
|
||||||
"position": {
|
"position": {
|
||||||
"x": 360.43473114516974,
|
"x": 279.5,
|
||||||
"y": 207.29298425089348
|
"y": 190
|
||||||
},
|
},
|
||||||
"positionAbsolute": {
|
"selected": true,
|
||||||
"x": 360.43473114516974,
|
|
||||||
"y": 207.29298425089348
|
|
||||||
},
|
|
||||||
"selected": false,
|
|
||||||
"sourcePosition": "right",
|
"sourcePosition": "right",
|
||||||
"targetPosition": "left",
|
"targetPosition": "left",
|
||||||
"type": "logicNode",
|
"type": "messageNode"
|
||||||
"width": 200
|
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"history": [],
|
"history": [],
|
||||||
|
"memory": [],
|
||||||
"messages": [],
|
"messages": [],
|
||||||
"path": [
|
"path": [],
|
||||||
[
|
"retrieval": [],
|
||||||
"begin"
|
"task_id": "dbb4ed366e8611f09690a55a6daec4ef"
|
||||||
],
|
|
||||||
[]
|
|
||||||
],
|
|
||||||
"reference": []
|
|
||||||
},
|
},
|
||||||
"id": "2581031eb7a311efb5200242ac120005",
|
"id": "0b02fe80780e11f084adcfdc3ed1d902",
|
||||||
"message": [
|
"message": [
|
||||||
{
|
{
|
||||||
"content": "Hi! I'm your smart assistant. What can I do for you?",
|
"content": "Hi! I'm your assistant, what can I do for you?",
|
||||||
"role": "assistant"
|
"role": "assistant"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": "agent",
|
"source": "agent",
|
||||||
"user_id": "69736c5e723611efb51b0242ac120007"
|
"user_id": "c3fb861af27a11efa69751e139332ced"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|||||||
@ -105,4 +105,5 @@ REMEMBER:
|
|||||||
- Cite FACTS, not opinions or transitions
|
- Cite FACTS, not opinions or transitions
|
||||||
- Each citation supports the ENTIRE sentence
|
- Each citation supports the ENTIRE sentence
|
||||||
- When in doubt, ask: "Would a fact-checker need to verify this?"
|
- When in doubt, ask: "Would a fact-checker need to verify this?"
|
||||||
- Place citations at sentence end, before punctuation
|
- Place citations at sentence end, before punctuation
|
||||||
|
- Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be seperated like, [ID:0][ID:5]...
|
||||||
|
|||||||
@ -197,7 +197,7 @@ def question_proposal(chat_mdl, content, topn=3):
|
|||||||
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
|
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
|
|
||||||
if not chat_mdl:
|
if not chat_mdl:
|
||||||
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
||||||
@ -231,7 +231,7 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_
|
|||||||
def cross_languages(tenant_id, llm_id, query, languages=[]):
|
def cross_languages(tenant_id, llm_id, query, languages=[]):
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
|
|
||||||
if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
||||||
|
|||||||
@ -191,7 +191,6 @@ class RAGFlowS3:
|
|||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
return
|
return
|
||||||
|
|
||||||
@use_prefix_path
|
|
||||||
@use_default_bucket
|
@use_default_bucket
|
||||||
def rm_bucket(self, bucket, *args, **kwargs):
|
def rm_bucket(self, bucket, *args, **kwargs):
|
||||||
for conn in self.conn:
|
for conn in self.conn:
|
||||||
|
|||||||
46
web/src/components/originui/password-input.tsx
Normal file
46
web/src/components/originui/password-input.tsx
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
// https://originui.com/r/comp-23.json
|
||||||
|
|
||||||
|
'use client';
|
||||||
|
|
||||||
|
import { EyeIcon, EyeOffIcon } from 'lucide-react';
|
||||||
|
import React, { useId, useState } from 'react';
|
||||||
|
import { Input, InputProps } from '../ui/input';
|
||||||
|
|
||||||
|
export default React.forwardRef<HTMLInputElement, InputProps>(
|
||||||
|
function PasswordInput({ ...props }, ref) {
|
||||||
|
const id = useId();
|
||||||
|
const [isVisible, setIsVisible] = useState<boolean>(false);
|
||||||
|
|
||||||
|
const toggleVisibility = () => setIsVisible((prevState) => !prevState);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="*:not-first:mt-2">
|
||||||
|
{/* <Label htmlFor={id}>Show/hide password input</Label> */}
|
||||||
|
<div className="relative">
|
||||||
|
<Input
|
||||||
|
id={id}
|
||||||
|
className="pe-9"
|
||||||
|
placeholder="Password"
|
||||||
|
type={isVisible ? 'text' : 'password'}
|
||||||
|
ref={ref}
|
||||||
|
{...props}
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
className="text-muted-foreground/80 hover:text-foreground focus-visible:border-ring focus-visible:ring-ring/50 absolute inset-y-0 end-0 flex h-full w-9 items-center justify-center rounded-e-md transition-[color,box-shadow] outline-none focus:z-10 focus-visible:ring-[3px] disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50"
|
||||||
|
type="button"
|
||||||
|
onClick={toggleVisibility}
|
||||||
|
aria-label={isVisible ? 'Hide password' : 'Show password'}
|
||||||
|
aria-pressed={isVisible}
|
||||||
|
aria-controls="password"
|
||||||
|
>
|
||||||
|
{isVisible ? (
|
||||||
|
<EyeOffIcon size={16} aria-hidden="true" />
|
||||||
|
) : (
|
||||||
|
<EyeIcon size={16} aria-hidden="true" />
|
||||||
|
)}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
},
|
||||||
|
);
|
||||||
@ -2,7 +2,7 @@ import { PropsWithChildren } from 'react';
|
|||||||
|
|
||||||
export function PageHeader({ children }: PropsWithChildren) {
|
export function PageHeader({ children }: PropsWithChildren) {
|
||||||
return (
|
return (
|
||||||
<header className="flex justify-between items-center border-b bg-text-title-invert p-5">
|
<header className="flex justify-between items-center bg-text-title-invert p-5">
|
||||||
{children}
|
{children}
|
||||||
</header>
|
</header>
|
||||||
);
|
);
|
||||||
|
|||||||
@ -1,52 +0,0 @@
|
|||||||
import { Input } from '@/components/originui/input';
|
|
||||||
import { EyeIcon, EyeOffIcon } from 'lucide-react';
|
|
||||||
import { ChangeEvent, forwardRef, useId, useState } from 'react';
|
|
||||||
|
|
||||||
type PropType = {
|
|
||||||
name: string;
|
|
||||||
value: string;
|
|
||||||
onBlur: () => void;
|
|
||||||
onChange: (event: ChangeEvent<HTMLInputElement>) => void;
|
|
||||||
};
|
|
||||||
|
|
||||||
function PasswordInput(props: PropType) {
|
|
||||||
const id = useId();
|
|
||||||
const [isVisible, setIsVisible] = useState<boolean>(false);
|
|
||||||
|
|
||||||
const toggleVisibility = () => setIsVisible((prevState) => !prevState);
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="*:not-first:mt-2 w-full">
|
|
||||||
{/* <Label htmlFor={id}>Show/hide password input</Label> */}
|
|
||||||
<div className="relative">
|
|
||||||
<Input
|
|
||||||
autoComplete="off"
|
|
||||||
inputMode="numeric"
|
|
||||||
id={id}
|
|
||||||
className="pe-9"
|
|
||||||
placeholder=""
|
|
||||||
type={isVisible ? 'text' : 'password'}
|
|
||||||
value={props.value}
|
|
||||||
onBlur={props.onBlur}
|
|
||||||
onChange={(ev) => props.onChange(ev)}
|
|
||||||
/>
|
|
||||||
<button
|
|
||||||
className="text-muted-foreground/80 hover:text-foreground focus-visible:border-ring focus-visible:ring-ring/50 absolute inset-y-0 end-0 flex h-full w-9 items-center justify-center rounded-e-md transition-[color,box-shadow] outline-none focus:z-10 focus-visible:ring-[3px] disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50"
|
|
||||||
type="button"
|
|
||||||
onClick={toggleVisibility}
|
|
||||||
aria-label={isVisible ? 'Hide password' : 'Show password'}
|
|
||||||
aria-pressed={isVisible}
|
|
||||||
aria-controls="password"
|
|
||||||
>
|
|
||||||
{isVisible ? (
|
|
||||||
<EyeOffIcon size={16} aria-hidden="true" />
|
|
||||||
) : (
|
|
||||||
<EyeIcon size={16} aria-hidden="true" />
|
|
||||||
)}
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
export default forwardRef(PasswordInput);
|
|
||||||
51
web/src/components/tavily-form-field.tsx
Normal file
51
web/src/components/tavily-form-field.tsx
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import { useTranslate } from '@/hooks/common-hooks';
|
||||||
|
import { useFormContext } from 'react-hook-form';
|
||||||
|
import PasswordInput from './originui/password-input';
|
||||||
|
import {
|
||||||
|
FormControl,
|
||||||
|
FormDescription,
|
||||||
|
FormField,
|
||||||
|
FormItem,
|
||||||
|
FormLabel,
|
||||||
|
FormMessage,
|
||||||
|
} from './ui/form';
|
||||||
|
|
||||||
|
interface IProps {
|
||||||
|
name?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function TavilyFormField({
|
||||||
|
name = 'prompt_config.tavily_api_key',
|
||||||
|
}: IProps) {
|
||||||
|
const form = useFormContext();
|
||||||
|
const { t } = useTranslate('chat');
|
||||||
|
|
||||||
|
return (
|
||||||
|
<FormField
|
||||||
|
control={form.control}
|
||||||
|
name={name}
|
||||||
|
render={({ field }) => (
|
||||||
|
<FormItem>
|
||||||
|
<FormLabel tooltip={t('tavilyApiKeyTip')}>Tavily API Key</FormLabel>
|
||||||
|
<FormControl>
|
||||||
|
<PasswordInput
|
||||||
|
{...field}
|
||||||
|
placeholder={t('tavilyApiKeyMessage')}
|
||||||
|
autoComplete="new-password"
|
||||||
|
></PasswordInput>
|
||||||
|
</FormControl>
|
||||||
|
<FormDescription>
|
||||||
|
<a
|
||||||
|
href="https://app.tavily.com/home"
|
||||||
|
target={'_blank'}
|
||||||
|
rel="noreferrer"
|
||||||
|
>
|
||||||
|
{t('tavilyApiKeyHelp')}
|
||||||
|
</a>
|
||||||
|
</FormDescription>
|
||||||
|
<FormMessage />
|
||||||
|
</FormItem>
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
);
|
||||||
|
}
|
||||||
@ -3,6 +3,7 @@
|
|||||||
import { FileUploader } from '@/components/file-uploader';
|
import { FileUploader } from '@/components/file-uploader';
|
||||||
import { KnowledgeBaseFormField } from '@/components/knowledge-base-item';
|
import { KnowledgeBaseFormField } from '@/components/knowledge-base-item';
|
||||||
import { SwitchFormField } from '@/components/switch-fom-field';
|
import { SwitchFormField } from '@/components/switch-fom-field';
|
||||||
|
import { TavilyFormField } from '@/components/tavily-form-field';
|
||||||
import {
|
import {
|
||||||
FormControl,
|
FormControl,
|
||||||
FormField,
|
FormField,
|
||||||
@ -105,6 +106,7 @@ export default function ChatBasicSetting() {
|
|||||||
name={'prompt_config.tts'}
|
name={'prompt_config.tts'}
|
||||||
label={t('tts')}
|
label={t('tts')}
|
||||||
></SwitchFormField>
|
></SwitchFormField>
|
||||||
|
<TavilyFormField></TavilyFormField>
|
||||||
<KnowledgeBaseFormField></KnowledgeBaseFormField>
|
<KnowledgeBaseFormField></KnowledgeBaseFormField>
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
|
|||||||
@ -68,8 +68,8 @@ export function ChatSettings({ switchSettingVisible }: ChatSettingsProps) {
|
|||||||
}, [data, form]);
|
}, [data, form]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<section className="p-5 w-[440px] ">
|
<section className="p-5 w-[440px] border-l">
|
||||||
<div className="flex justify-between items-center text-base">
|
<div className="flex justify-between items-center text-base pb-2">
|
||||||
Chat Settings
|
Chat Settings
|
||||||
<X className="size-4 cursor-pointer" onClick={switchSettingVisible} />
|
<X className="size-4 cursor-pointer" onClick={switchSettingVisible} />
|
||||||
</div>
|
</div>
|
||||||
|
|||||||
@ -24,6 +24,7 @@ export function useChatSettingSchema() {
|
|||||||
optional: z.boolean(),
|
optional: z.boolean(),
|
||||||
}),
|
}),
|
||||||
),
|
),
|
||||||
|
tavily_api_key: z.string().optional(),
|
||||||
});
|
});
|
||||||
|
|
||||||
const formSchema = z.object({
|
const formSchema = z.object({
|
||||||
|
|||||||
155
web/src/pages/next-chats/chat/chat-box/multiple-chat-box.tsx
Normal file
155
web/src/pages/next-chats/chat/chat-box/multiple-chat-box.tsx
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
import { NextMessageInput } from '@/components/message-input/next';
|
||||||
|
import MessageItem from '@/components/message-item';
|
||||||
|
import { Button } from '@/components/ui/button';
|
||||||
|
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||||
|
import { MessageType } from '@/constants/chat';
|
||||||
|
import {
|
||||||
|
useFetchConversation,
|
||||||
|
useFetchDialog,
|
||||||
|
useGetChatSearchParams,
|
||||||
|
} from '@/hooks/use-chat-request';
|
||||||
|
import { useFetchUserInfo } from '@/hooks/user-setting-hooks';
|
||||||
|
import { buildMessageUuidWithRole } from '@/utils/chat';
|
||||||
|
import { Trash2 } from 'lucide-react';
|
||||||
|
import { useCallback } from 'react';
|
||||||
|
import {
|
||||||
|
useGetSendButtonDisabled,
|
||||||
|
useSendButtonDisabled,
|
||||||
|
} from '../../hooks/use-button-disabled';
|
||||||
|
import { useCreateConversationBeforeUploadDocument } from '../../hooks/use-create-conversation';
|
||||||
|
import { useSendMessage } from '../../hooks/use-send-chat-message';
|
||||||
|
import { buildMessageItemReference } from '../../utils';
|
||||||
|
import { useAddChatBox } from '../use-add-box';
|
||||||
|
|
||||||
|
type MultipleChatBoxProps = {
|
||||||
|
controller: AbortController;
|
||||||
|
chatBoxIds: string[];
|
||||||
|
} & Pick<ReturnType<typeof useAddChatBox>, 'removeChatBox'>;
|
||||||
|
|
||||||
|
type ChatCardProps = { id: string } & Pick<
|
||||||
|
MultipleChatBoxProps,
|
||||||
|
'controller' | 'removeChatBox'
|
||||||
|
>;
|
||||||
|
|
||||||
|
function ChatCard({ controller, removeChatBox, id }: ChatCardProps) {
|
||||||
|
const {
|
||||||
|
value,
|
||||||
|
// scrollRef,
|
||||||
|
messageContainerRef,
|
||||||
|
sendLoading,
|
||||||
|
derivedMessages,
|
||||||
|
handleInputChange,
|
||||||
|
handlePressEnter,
|
||||||
|
regenerateMessage,
|
||||||
|
removeMessageById,
|
||||||
|
stopOutputMessage,
|
||||||
|
} = useSendMessage(controller);
|
||||||
|
|
||||||
|
const { data: userInfo } = useFetchUserInfo();
|
||||||
|
const { data: currentDialog } = useFetchDialog();
|
||||||
|
const { data: conversation } = useFetchConversation();
|
||||||
|
|
||||||
|
const handleRemoveChatBox = useCallback(() => {
|
||||||
|
removeChatBox(id);
|
||||||
|
}, [id, removeChatBox]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Card className="bg-transparent border flex-1">
|
||||||
|
<CardHeader className="border-b px-5 py-3">
|
||||||
|
<CardTitle className="flex justify-between items-center">
|
||||||
|
<div>
|
||||||
|
<span className="text-base">Card Title</span>
|
||||||
|
<Button variant={'ghost'} className="ml-2">
|
||||||
|
GPT-4
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<Button variant={'ghost'} onClick={handleRemoveChatBox}>
|
||||||
|
<Trash2 />
|
||||||
|
</Button>
|
||||||
|
</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent>
|
||||||
|
<div ref={messageContainerRef} className="flex-1 overflow-auto min-h-0">
|
||||||
|
<div className="w-full">
|
||||||
|
{derivedMessages?.map((message, i) => {
|
||||||
|
return (
|
||||||
|
<MessageItem
|
||||||
|
loading={
|
||||||
|
message.role === MessageType.Assistant &&
|
||||||
|
sendLoading &&
|
||||||
|
derivedMessages.length - 1 === i
|
||||||
|
}
|
||||||
|
key={buildMessageUuidWithRole(message)}
|
||||||
|
item={message}
|
||||||
|
nickname={userInfo.nickname}
|
||||||
|
avatar={userInfo.avatar}
|
||||||
|
avatarDialog={currentDialog.icon}
|
||||||
|
reference={buildMessageItemReference(
|
||||||
|
{
|
||||||
|
message: derivedMessages,
|
||||||
|
reference: conversation.reference,
|
||||||
|
},
|
||||||
|
message,
|
||||||
|
)}
|
||||||
|
// clickDocumentButton={clickDocumentButton}
|
||||||
|
index={i}
|
||||||
|
removeMessageById={removeMessageById}
|
||||||
|
regenerateMessage={regenerateMessage}
|
||||||
|
sendLoading={sendLoading}
|
||||||
|
></MessageItem>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
{/* <div ref={scrollRef} /> */}
|
||||||
|
</div>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function MultipleChatBox({
|
||||||
|
controller,
|
||||||
|
chatBoxIds,
|
||||||
|
removeChatBox,
|
||||||
|
}: MultipleChatBoxProps) {
|
||||||
|
const {
|
||||||
|
value,
|
||||||
|
sendLoading,
|
||||||
|
handleInputChange,
|
||||||
|
handlePressEnter,
|
||||||
|
stopOutputMessage,
|
||||||
|
} = useSendMessage(controller);
|
||||||
|
|
||||||
|
const { createConversationBeforeUploadDocument } =
|
||||||
|
useCreateConversationBeforeUploadDocument();
|
||||||
|
const { conversationId } = useGetChatSearchParams();
|
||||||
|
const disabled = useGetSendButtonDisabled();
|
||||||
|
const sendDisabled = useSendButtonDisabled(value);
|
||||||
|
return (
|
||||||
|
<section className="h-full flex flex-col">
|
||||||
|
<div className="flex gap-4 flex-1 px-5 pb-12">
|
||||||
|
{chatBoxIds.map((id) => (
|
||||||
|
<ChatCard
|
||||||
|
key={id}
|
||||||
|
controller={controller}
|
||||||
|
id={id}
|
||||||
|
removeChatBox={removeChatBox}
|
||||||
|
></ChatCard>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
<NextMessageInput
|
||||||
|
disabled={disabled}
|
||||||
|
sendDisabled={sendDisabled}
|
||||||
|
sendLoading={sendLoading}
|
||||||
|
value={value}
|
||||||
|
onInputChange={handleInputChange}
|
||||||
|
onPressEnter={handlePressEnter}
|
||||||
|
conversationId={conversationId}
|
||||||
|
createConversationBeforeUploadDocument={
|
||||||
|
createConversationBeforeUploadDocument
|
||||||
|
}
|
||||||
|
stopOutputMessage={stopOutputMessage}
|
||||||
|
/>
|
||||||
|
</section>
|
||||||
|
);
|
||||||
|
}
|
||||||
@ -11,16 +11,16 @@ import { buildMessageUuidWithRole } from '@/utils/chat';
|
|||||||
import {
|
import {
|
||||||
useGetSendButtonDisabled,
|
useGetSendButtonDisabled,
|
||||||
useSendButtonDisabled,
|
useSendButtonDisabled,
|
||||||
} from '../hooks/use-button-disabled';
|
} from '../../hooks/use-button-disabled';
|
||||||
import { useCreateConversationBeforeUploadDocument } from '../hooks/use-create-conversation';
|
import { useCreateConversationBeforeUploadDocument } from '../../hooks/use-create-conversation';
|
||||||
import { useSendMessage } from '../hooks/use-send-chat-message';
|
import { useSendMessage } from '../../hooks/use-send-chat-message';
|
||||||
import { buildMessageItemReference } from '../utils';
|
import { buildMessageItemReference } from '../../utils';
|
||||||
|
|
||||||
interface IProps {
|
interface IProps {
|
||||||
controller: AbortController;
|
controller: AbortController;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function ChatBox({ controller }: IProps) {
|
export function SingleChatBox({ controller }: IProps) {
|
||||||
const {
|
const {
|
||||||
value,
|
value,
|
||||||
// scrollRef,
|
// scrollRef,
|
||||||
@ -43,7 +43,7 @@ export function ChatBox({ controller }: IProps) {
|
|||||||
const sendDisabled = useSendButtonDisabled(value);
|
const sendDisabled = useSendButtonDisabled(value);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<section className="border-x flex flex-col p-5 flex-1 min-w-0">
|
<section className="flex flex-col p-5 h-full">
|
||||||
<div ref={messageContainerRef} className="flex-1 overflow-auto min-h-0">
|
<div ref={messageContainerRef} className="flex-1 overflow-auto min-h-0">
|
||||||
<div className="w-full">
|
<div className="w-full">
|
||||||
{derivedMessages?.map((message, i) => {
|
{derivedMessages?.map((message, i) => {
|
||||||
@ -7,14 +7,20 @@ import {
|
|||||||
BreadcrumbPage,
|
BreadcrumbPage,
|
||||||
BreadcrumbSeparator,
|
BreadcrumbSeparator,
|
||||||
} from '@/components/ui/breadcrumb';
|
} from '@/components/ui/breadcrumb';
|
||||||
|
import { Button } from '@/components/ui/button';
|
||||||
|
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||||
import { useSetModalState } from '@/hooks/common-hooks';
|
import { useSetModalState } from '@/hooks/common-hooks';
|
||||||
import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks';
|
import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks';
|
||||||
import { useFetchDialog } from '@/hooks/use-chat-request';
|
import { useFetchDialog } from '@/hooks/use-chat-request';
|
||||||
|
import { cn } from '@/lib/utils';
|
||||||
|
import { Plus } from 'lucide-react';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
import { useHandleClickConversationCard } from '../hooks/use-click-card';
|
import { useHandleClickConversationCard } from '../hooks/use-click-card';
|
||||||
import { ChatSettings } from './app-settings/chat-settings';
|
import { ChatSettings } from './app-settings/chat-settings';
|
||||||
import { ChatBox } from './chat-box';
|
import { MultipleChatBox } from './chat-box/multiple-chat-box';
|
||||||
|
import { SingleChatBox } from './chat-box/single-chat-box';
|
||||||
import { Sessions } from './sessions';
|
import { Sessions } from './sessions';
|
||||||
|
import { useAddChatBox } from './use-add-box';
|
||||||
|
|
||||||
export default function Chat() {
|
export default function Chat() {
|
||||||
const { navigateToChatList } = useNavigatePage();
|
const { navigateToChatList } = useNavigatePage();
|
||||||
@ -24,9 +30,16 @@ export default function Chat() {
|
|||||||
useHandleClickConversationCard();
|
useHandleClickConversationCard();
|
||||||
const { visible: settingVisible, switchVisible: switchSettingVisible } =
|
const { visible: settingVisible, switchVisible: switchSettingVisible } =
|
||||||
useSetModalState(true);
|
useSetModalState(true);
|
||||||
|
const {
|
||||||
|
removeChatBox,
|
||||||
|
addChatBox,
|
||||||
|
chatBoxIds,
|
||||||
|
hasSingleChatBox,
|
||||||
|
hasThreeChatBox,
|
||||||
|
} = useAddChatBox();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<section className="h-full flex flex-col">
|
<section className="h-full flex flex-col pr-5">
|
||||||
<PageHeader>
|
<PageHeader>
|
||||||
<Breadcrumb>
|
<Breadcrumb>
|
||||||
<BreadcrumbList>
|
<BreadcrumbList>
|
||||||
@ -43,18 +56,52 @@ export default function Chat() {
|
|||||||
</Breadcrumb>
|
</Breadcrumb>
|
||||||
</PageHeader>
|
</PageHeader>
|
||||||
<div className="flex flex-1 min-h-0">
|
<div className="flex flex-1 min-h-0">
|
||||||
<div className="flex flex-1 min-w-0">
|
<Sessions
|
||||||
<Sessions
|
handleConversationCardClick={handleConversationCardClick}
|
||||||
handleConversationCardClick={handleConversationCardClick}
|
switchSettingVisible={switchSettingVisible}
|
||||||
switchSettingVisible={switchSettingVisible}
|
></Sessions>
|
||||||
></Sessions>
|
|
||||||
<ChatBox controller={controller}></ChatBox>
|
<Card className="flex-1 min-w-0 bg-transparent border h-full">
|
||||||
</div>
|
<CardContent className="flex p-0 h-full">
|
||||||
{settingVisible && (
|
<Card className="flex flex-col flex-1 bg-transparent">
|
||||||
<ChatSettings
|
<CardHeader
|
||||||
switchSettingVisible={switchSettingVisible}
|
className={cn('p-5', { 'border-b': hasSingleChatBox })}
|
||||||
></ChatSettings>
|
>
|
||||||
)}
|
<CardTitle className="flex justify-between items-center">
|
||||||
|
<div className="text-base">
|
||||||
|
Card Title
|
||||||
|
<Button variant={'ghost'} className="ml-2">
|
||||||
|
GPT-4
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
<Button
|
||||||
|
variant={'ghost'}
|
||||||
|
onClick={addChatBox}
|
||||||
|
disabled={hasThreeChatBox}
|
||||||
|
>
|
||||||
|
<Plus></Plus> Multiple Models
|
||||||
|
</Button>
|
||||||
|
</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent className="flex-1 p-0">
|
||||||
|
{hasSingleChatBox ? (
|
||||||
|
<SingleChatBox controller={controller}></SingleChatBox>
|
||||||
|
) : (
|
||||||
|
<MultipleChatBox
|
||||||
|
chatBoxIds={chatBoxIds}
|
||||||
|
controller={controller}
|
||||||
|
removeChatBox={removeChatBox}
|
||||||
|
></MultipleChatBox>
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
{settingVisible && (
|
||||||
|
<ChatSettings
|
||||||
|
switchSettingVisible={switchSettingVisible}
|
||||||
|
></ChatSettings>
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
</section>
|
</section>
|
||||||
);
|
);
|
||||||
|
|||||||
26
web/src/pages/next-chats/chat/use-add-box.ts
Normal file
26
web/src/pages/next-chats/chat/use-add-box.ts
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import { useCallback, useState } from 'react';
|
||||||
|
import { v4 as uuid } from 'uuid';
|
||||||
|
|
||||||
|
export function useAddChatBox() {
|
||||||
|
const [ids, setIds] = useState<string[]>([uuid()]);
|
||||||
|
|
||||||
|
const hasSingleChatBox = ids.length === 1;
|
||||||
|
|
||||||
|
const hasThreeChatBox = ids.length === 3;
|
||||||
|
|
||||||
|
const addChatBox = useCallback(() => {
|
||||||
|
setIds((prev) => [...prev, uuid()]);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const removeChatBox = useCallback((id: string) => {
|
||||||
|
setIds((prev) => prev.filter((x) => x !== id));
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return {
|
||||||
|
chatBoxIds: ids,
|
||||||
|
hasSingleChatBox,
|
||||||
|
hasThreeChatBox,
|
||||||
|
addChatBox,
|
||||||
|
removeChatBox,
|
||||||
|
};
|
||||||
|
}
|
||||||
@ -1,4 +1,4 @@
|
|||||||
import PasswordInput from '@/components/password-input';
|
import PasswordInput from '@/components/originui/password-input';
|
||||||
import { Avatar, AvatarFallback, AvatarImage } from '@/components/ui/avatar';
|
import { Avatar, AvatarFallback, AvatarImage } from '@/components/ui/avatar';
|
||||||
import { Button } from '@/components/ui/button';
|
import { Button } from '@/components/ui/button';
|
||||||
import {
|
import {
|
||||||
|
|||||||
Reference in New Issue
Block a user