mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-04 03:25:30 +08:00
Compare commits
7 Commits
da5cef0686
...
5e8cd693a5
| Author | SHA1 | Date | |
|---|---|---|---|
| 5e8cd693a5 | |||
| 29f297b850 | |||
| 7235638607 | |||
| 00919fd599 | |||
| 43c0792ffd | |||
| 4b1b68c5fc | |||
| 3492f54c7a |
46
.github/ISSUE_TEMPLATE/agent_scenario_request.yml
vendored
Normal file
46
.github/ISSUE_TEMPLATE/agent_scenario_request.yml
vendored
Normal file
@ -0,0 +1,46 @@
|
||||
name: "❤️🔥ᴬᴳᴱᴺᵀ Agent scenario request"
|
||||
description: Propose a agent scenario request for RAGFlow.
|
||||
title: "[Agent Scenario Request]: "
|
||||
labels: ["❤️🔥ᴬᴳᴱᴺᵀ agent scenario"]
|
||||
body:
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Self Checks
|
||||
description: "Please check the following in order to be responded in time :)"
|
||||
options:
|
||||
- label: I have searched for existing issues [search for existing issues](https://github.com/infiniflow/ragflow/issues), including closed ones.
|
||||
required: true
|
||||
- label: I confirm that I am using English to submit this report ([Language Policy](https://github.com/infiniflow/ragflow/issues/5910)).
|
||||
required: true
|
||||
- label: Non-english title submitions will be closed directly ( 非英文标题的提交将会被直接关闭 ) ([Language Policy](https://github.com/infiniflow/ragflow/issues/5910)).
|
||||
required: true
|
||||
- label: "Please do not modify this template :) and fill in all the required fields."
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Is your feature request related to a scenario?
|
||||
description: |
|
||||
A clear and concise description of what the scenario is. Ex. I'm always frustrated when [...]
|
||||
render: Markdown
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Describe the feature you'd like
|
||||
description: A clear and concise description of what you want to happen.
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Documentation, adoption, use case
|
||||
description: If you can, explain some scenarios how users might use this, situations it would be helpful in. Any API designs, mockups, or diagrams are also helpful.
|
||||
render: Markdown
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Additional information
|
||||
description: |
|
||||
Add any other context or screenshots about the feature request here.
|
||||
validations:
|
||||
required: false
|
||||
@ -24,7 +24,8 @@ from typing import Any
|
||||
import json_repair
|
||||
|
||||
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
|
||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.mcp_server_service import MCPServerService
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.prompts import message_fit_in
|
||||
|
||||
@ -24,7 +24,8 @@ from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from api.utils.api_utils import timeout
|
||||
from rag.prompts import message_fit_in, citation_prompt
|
||||
|
||||
@ -28,8 +28,8 @@ from api.db.db_models import APIToken
|
||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||
from api.db.services.dialog_service import DialogService, ask, chat
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle, TenantService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import UserTenantService, TenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
from rag.app.tag import label_question
|
||||
|
||||
@ -18,7 +18,7 @@ from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api import settings
|
||||
|
||||
@ -17,7 +17,8 @@ import logging
|
||||
import json
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from api.db.services.llm_service import LLMFactoriesService, TenantLLMService, LLMService
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api import settings
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.db import StatusEnum, LLMType
|
||||
|
||||
@ -21,7 +21,7 @@ from api import settings
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_result, token_required
|
||||
|
||||
@ -32,7 +32,8 @@ from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.task_service import TaskService, queue_tasks
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
|
||||
@ -16,10 +16,8 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
|
||||
import tiktoken
|
||||
from flask import Response, jsonify, request
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.db import LLMType, StatusEnum
|
||||
from api.db.db_models import APIToken
|
||||
@ -29,7 +27,6 @@ from api.db.services.canvas_service import completion as agent_completion
|
||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
||||
from api.db.services.conversation_service import completion as rag_completion
|
||||
from api.db.services.dialog_service import DialogService, ask, chat
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils import get_uuid
|
||||
@ -69,11 +66,7 @@ def create(tenant_id, chat_id):
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
def create_agent_session(tenant_id, agent_id):
|
||||
req = request.json
|
||||
if not request.is_json:
|
||||
req = request.form
|
||||
files = request.files
|
||||
user_id = request.args.get("user_id", "")
|
||||
user_id = request.args.get("user_id", tenant_id)
|
||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not e:
|
||||
return get_error_data_result("Agent not found.")
|
||||
@ -82,46 +75,21 @@ def create_agent_session(tenant_id, agent_id):
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
|
||||
canvas = Canvas(cvs.dsl, tenant_id)
|
||||
session_id=get_uuid()
|
||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
||||
canvas.reset()
|
||||
query = canvas.get_preset_param()
|
||||
if query:
|
||||
for ele in query:
|
||||
if not ele["optional"]:
|
||||
if ele["type"] == "file":
|
||||
if files is None or not files.get(ele["key"]):
|
||||
return get_error_data_result(f"`{ele['key']}` with type `{ele['type']}` is required")
|
||||
upload_file = files.get(ele["key"])
|
||||
file_content = FileService.parse_docs([upload_file], user_id)
|
||||
file_name = upload_file.filename
|
||||
ele["value"] = file_name + "\n" + file_content
|
||||
else:
|
||||
if req is None or not req.get(ele["key"]):
|
||||
return get_error_data_result(f"`{ele['key']}` with type `{ele['type']}` is required")
|
||||
ele["value"] = req[ele["key"]]
|
||||
else:
|
||||
if ele["type"] == "file":
|
||||
if files is not None and files.get(ele["key"]):
|
||||
upload_file = files.get(ele["key"])
|
||||
file_content = FileService.parse_docs([upload_file], user_id)
|
||||
file_name = upload_file.filename
|
||||
ele["value"] = file_name + "\n" + file_content
|
||||
else:
|
||||
if "value" in ele:
|
||||
ele.pop("value")
|
||||
else:
|
||||
if req is not None and req.get(ele["key"]):
|
||||
ele["value"] = req[ele["key"]]
|
||||
else:
|
||||
if "value" in ele:
|
||||
ele.pop("value")
|
||||
|
||||
for ans in canvas.run(stream=False):
|
||||
pass
|
||||
conv = {
|
||||
"id": session_id,
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": user_id,
|
||||
"message": [],
|
||||
"source": "agent",
|
||||
"dsl": cvs.dsl
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
||||
API4ConversationService.save(**conv)
|
||||
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
|
||||
conv["agent_id"] = conv.pop("dialog_id")
|
||||
return get_result(data=conv)
|
||||
|
||||
|
||||
@ -28,7 +28,7 @@ from api.apps.auth import get_auth_client
|
||||
from api.db import FileType, UserTenantRole
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
||||
from api.db.services.llm_service import TenantLLMService, get_init_tenant_llm
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from api.utils import (
|
||||
current_timestamp,
|
||||
@ -619,57 +619,8 @@ def user_register(user_id, user):
|
||||
"size": 0,
|
||||
"location": "",
|
||||
}
|
||||
tenant_llm = []
|
||||
|
||||
seen = set()
|
||||
factory_configs = []
|
||||
for factory_config in [
|
||||
settings.CHAT_CFG,
|
||||
settings.EMBEDDING_CFG,
|
||||
settings.ASR_CFG,
|
||||
settings.IMAGE2TEXT_CFG,
|
||||
settings.RERANK_CFG,
|
||||
]:
|
||||
factory_name = factory_config["factory"]
|
||||
if factory_name not in seen:
|
||||
seen.add(factory_name)
|
||||
factory_configs.append(factory_config)
|
||||
|
||||
for factory_config in factory_configs:
|
||||
for llm in LLMService.query(fid=factory_config["factory"]):
|
||||
tenant_llm.append(
|
||||
{
|
||||
"tenant_id": user_id,
|
||||
"llm_factory": factory_config["factory"],
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": factory_config["api_key"],
|
||||
"api_base": factory_config["base_url"],
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||
}
|
||||
)
|
||||
|
||||
if settings.LIGHTEN != 1:
|
||||
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
||||
tenant_llm.append(
|
||||
{
|
||||
"tenant_id": user_id,
|
||||
"llm_factory": fid,
|
||||
"llm_name": mdlnm,
|
||||
"model_type": "embedding",
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
|
||||
}
|
||||
)
|
||||
|
||||
unique = {}
|
||||
for item in tenant_llm:
|
||||
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
||||
if key not in unique:
|
||||
unique[key] = item
|
||||
tenant_llm = list(unique.values())
|
||||
tenant_llm = get_init_tenant_llm(user_id)
|
||||
|
||||
if not UserService.save(**user):
|
||||
return
|
||||
|
||||
@ -27,7 +27,8 @@ from api.db.services import UserService
|
||||
from api.db.services.canvas_service import CanvasTemplateService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMFactoriesService, LLMService, TenantLLMService, LLMBundle
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api import settings
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
@ -64,43 +65,7 @@ def init_superuser():
|
||||
"role": UserTenantRole.OWNER
|
||||
}
|
||||
|
||||
user_id = user_info
|
||||
tenant_llm = []
|
||||
|
||||
seen = set()
|
||||
factory_configs = []
|
||||
for factory_config in [
|
||||
settings.CHAT_CFG["factory"],
|
||||
settings.EMBEDDING_CFG["factory"],
|
||||
settings.ASR_CFG["factory"],
|
||||
settings.IMAGE2TEXT_CFG["factory"],
|
||||
settings.RERANK_CFG["factory"],
|
||||
]:
|
||||
factory_name = factory_config["factory"]
|
||||
if factory_name not in seen:
|
||||
seen.add(factory_name)
|
||||
factory_configs.append(factory_config)
|
||||
|
||||
for factory_config in factory_configs:
|
||||
for llm in LLMService.query(fid=factory_config["factory"]):
|
||||
tenant_llm.append(
|
||||
{
|
||||
"tenant_id": user_id,
|
||||
"llm_factory": factory_config["factory"],
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": factory_config["api_key"],
|
||||
"api_base": factory_config["base_url"],
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||
}
|
||||
)
|
||||
|
||||
unique = {}
|
||||
for item in tenant_llm:
|
||||
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
||||
if key not in unique:
|
||||
unique[key] = item
|
||||
tenant_llm = list(unique.values())
|
||||
tenant_llm = get_init_tenant_llm(user_info["id"])
|
||||
|
||||
if not UserService.save(**user_info):
|
||||
logging.error("can't init admin.")
|
||||
|
||||
@ -33,7 +33,8 @@ from api.db.services.common_service import CommonService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import LLMBundle, TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils import current_timestamp, datetime_format
|
||||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.app.tag import label_question
|
||||
@ -365,8 +366,12 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if dialog.meta_data_filter.get("method") == "auto":
|
||||
filters = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
elif dialog.meta_data_filter.get("method") == "manual":
|
||||
attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"]))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
@ -375,17 +380,16 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
thought = ""
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
knowledges = []
|
||||
|
||||
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
||||
knowledges = []
|
||||
else:
|
||||
if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]:
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
knowledges = []
|
||||
if prompt_config.get("reasoning", False):
|
||||
reasoner = DeepResearcher(
|
||||
chat_mdl,
|
||||
prompt_config,
|
||||
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3),
|
||||
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3, doc_ids=attachments),
|
||||
)
|
||||
|
||||
for think in reasoner.thinking(kbinfos, " ".join(questions)):
|
||||
|
||||
@ -18,246 +18,73 @@ import logging
|
||||
import re
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
|
||||
from langfuse import Langfuse
|
||||
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
|
||||
from api.db.db_models import LLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.user_service import TenantService
|
||||
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
|
||||
|
||||
|
||||
class LLMFactoriesService(CommonService):
|
||||
model = LLMFactories
|
||||
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
|
||||
|
||||
|
||||
class LLMService(CommonService):
|
||||
model = LLM
|
||||
|
||||
|
||||
class TenantLLMService(CommonService):
|
||||
model = TenantLLM
|
||||
def get_init_tenant_llm(user_id):
|
||||
from api import settings
|
||||
tenant_llm = []
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_api_key(cls, tenant_id, model_name):
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
||||
if not fid:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
||||
else:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
seen = set()
|
||||
factory_configs = []
|
||||
for factory_config in [
|
||||
settings.CHAT_CFG,
|
||||
settings.EMBEDDING_CFG,
|
||||
settings.ASR_CFG,
|
||||
settings.IMAGE2TEXT_CFG,
|
||||
settings.RERANK_CFG,
|
||||
]:
|
||||
factory_name = factory_config["factory"]
|
||||
if factory_name not in seen:
|
||||
seen.add(factory_name)
|
||||
factory_configs.append(factory_config)
|
||||
|
||||
if (not objs) and fid:
|
||||
if fid == "LocalAI":
|
||||
mdlnm += "___LocalAI"
|
||||
elif fid == "HuggingFace":
|
||||
mdlnm += "___HuggingFace"
|
||||
elif fid == "OpenAI-API-Compatible":
|
||||
mdlnm += "___OpenAI-API"
|
||||
elif fid == "VLLM":
|
||||
mdlnm += "___VLLM"
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
if not objs:
|
||||
return
|
||||
return objs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
return list(objs)
|
||||
|
||||
@staticmethod
|
||||
def split_model_name_and_factory(model_name):
|
||||
arr = model_name.split("@")
|
||||
if len(arr) < 2:
|
||||
return model_name, None
|
||||
if len(arr) > 2:
|
||||
return "@".join(arr[0:-1]), arr[-1]
|
||||
|
||||
# model name must be xxx@yyy
|
||||
try:
|
||||
model_factories = settings.FACTORY_LLM_INFOS
|
||||
model_providers = set([f["name"] for f in model_factories])
|
||||
if arr[-1] not in model_providers:
|
||||
return model_name, None
|
||||
return arr[0], arr[-1]
|
||||
except Exception as e:
|
||||
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
|
||||
return model_name, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
raise LookupError("Tenant not found")
|
||||
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
mdlnm = tenant.embd_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
||||
mdlnm = tenant.asr_id
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
mdlnm = tenant.img2txt_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.RERANK:
|
||||
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.TTS:
|
||||
mdlnm = tenant.tts_id if not llm_name else llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
if not model_config: # for some cases seems fid mismatch
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||
if not llm and fid: # for some cases seems fid mismatch
|
||||
llm = LLMService.query(llm_name=mdlnm)
|
||||
if llm:
|
||||
model_config["is_tools"] = llm[0].is_tools
|
||||
if not model_config:
|
||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
||||
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
|
||||
if not model_config:
|
||||
if mdlnm == "flag-embedding":
|
||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
|
||||
else:
|
||||
if not mdlnm:
|
||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||
raise LookupError("Model({}) not authorized".format(mdlnm))
|
||||
return model_config
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
kwargs.update({"provider": model_config["llm_factory"]})
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
||||
if llm_type == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return
|
||||
return TTSModel[model_config["llm_factory"]](
|
||||
model_config["api_key"],
|
||||
model_config["llm_name"],
|
||||
base_url=model_config["api_base"],
|
||||
for factory_config in factory_configs:
|
||||
for llm in LLMService.query(fid=factory_config["factory"]):
|
||||
tenant_llm.append(
|
||||
{
|
||||
"tenant_id": user_id,
|
||||
"llm_factory": factory_config["factory"],
|
||||
"llm_name": llm.llm_name,
|
||||
"model_type": llm.model_type,
|
||||
"api_key": factory_config["api_key"],
|
||||
"api_base": factory_config["base_url"],
|
||||
"max_tokens": llm.max_tokens if llm.max_tokens else 8192,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
logging.error(f"Tenant not found: {tenant_id}")
|
||||
return 0
|
||||
|
||||
llm_map = {
|
||||
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
|
||||
LLMType.SPEECH2TEXT.value: tenant.asr_id,
|
||||
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
|
||||
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
|
||||
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
|
||||
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
|
||||
}
|
||||
|
||||
mdlnm = llm_map.get(llm_type)
|
||||
if mdlnm is None:
|
||||
logging.error(f"LLM type error: {llm_type}")
|
||||
return 0
|
||||
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
|
||||
try:
|
||||
num = (
|
||||
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
||||
.execute()
|
||||
if settings.LIGHTEN != 1:
|
||||
for buildin_embedding_model in settings.BUILTIN_EMBEDDING_MODELS:
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(buildin_embedding_model)
|
||||
tenant_llm.append(
|
||||
{
|
||||
"tenant_id": user_id,
|
||||
"llm_factory": fid,
|
||||
"llm_name": mdlnm,
|
||||
"model_type": "embedding",
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"max_tokens": 1024 if buildin_embedding_model == "BAAI/bge-large-zh-v1.5@BAAI" else 512,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
||||
return 0
|
||||
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_openai_models(cls):
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@staticmethod
|
||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
|
||||
llm_factories = settings.FACTORY_LLM_INFOS
|
||||
for llm_factory in llm_factories:
|
||||
for llm in llm_factory["llm"]:
|
||||
if llm_id == llm["llm_name"]:
|
||||
return llm["model_type"].split(",")[-1]
|
||||
|
||||
for llm in LLMService.query(llm_name=llm_id):
|
||||
return llm.model_type
|
||||
|
||||
llm = TenantLLMService.get_or_none(llm_name=llm_id)
|
||||
if llm:
|
||||
return llm.model_type
|
||||
for llm in TenantLLMService.query(llm_name=llm_id):
|
||||
return llm.model_type
|
||||
unique = {}
|
||||
for item in tenant_llm:
|
||||
key = (item["tenant_id"], item["llm_factory"], item["llm_name"])
|
||||
if key not in unique:
|
||||
unique[key] = item
|
||||
return list(unique.values())
|
||||
|
||||
|
||||
class LLMBundle:
|
||||
class LLMBundle(LLM4Tenant):
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_type = llm_type
|
||||
self.llm_name = llm_name
|
||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
self.max_length = model_config.get("max_tokens", 8192)
|
||||
|
||||
self.is_tools = model_config.get("is_tools", False)
|
||||
self.verbose_tool_use = kwargs.get("verbose_tool_use")
|
||||
|
||||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
|
||||
self.langfuse = None
|
||||
if langfuse_keys:
|
||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
||||
if langfuse.auth_check():
|
||||
self.langfuse = langfuse
|
||||
trace_id = self.langfuse.create_trace_id()
|
||||
self.trace_context = {"trace_id": trace_id}
|
||||
super().__init__(tenant_id, llm_type, llm_name, lang, **kwargs)
|
||||
|
||||
def bind_tools(self, toolcall_session, tools):
|
||||
if not self.is_tools:
|
||||
|
||||
252
api/db/services/tenant_llm_service.py
Normal file
252
api/db/services/tenant_llm_service.py
Normal file
@ -0,0 +1,252 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from langfuse import Langfuse
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import DB, LLMFactories, TenantLLM
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.user_service import TenantService
|
||||
from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel
|
||||
|
||||
|
||||
class LLMFactoriesService(CommonService):
|
||||
model = LLMFactories
|
||||
|
||||
|
||||
class TenantLLMService(CommonService):
|
||||
model = TenantLLM
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_api_key(cls, tenant_id, model_name):
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
|
||||
if not fid:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
|
||||
else:
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
|
||||
if (not objs) and fid:
|
||||
if fid == "LocalAI":
|
||||
mdlnm += "___LocalAI"
|
||||
elif fid == "HuggingFace":
|
||||
mdlnm += "___HuggingFace"
|
||||
elif fid == "OpenAI-API-Compatible":
|
||||
mdlnm += "___OpenAI-API"
|
||||
elif fid == "VLLM":
|
||||
mdlnm += "___VLLM"
|
||||
objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
|
||||
if not objs:
|
||||
return
|
||||
return objs[0]
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_my_llms(cls, tenant_id):
|
||||
fields = [cls.model.llm_factory, LLMFactories.logo, LLMFactories.tags, cls.model.model_type, cls.model.llm_name, cls.model.used_tokens]
|
||||
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()
|
||||
|
||||
return list(objs)
|
||||
|
||||
@staticmethod
|
||||
def split_model_name_and_factory(model_name):
|
||||
arr = model_name.split("@")
|
||||
if len(arr) < 2:
|
||||
return model_name, None
|
||||
if len(arr) > 2:
|
||||
return "@".join(arr[0:-1]), arr[-1]
|
||||
|
||||
# model name must be xxx@yyy
|
||||
try:
|
||||
model_factories = settings.FACTORY_LLM_INFOS
|
||||
model_providers = set([f["name"] for f in model_factories])
|
||||
if arr[-1] not in model_providers:
|
||||
return model_name, None
|
||||
return arr[0], arr[-1]
|
||||
except Exception as e:
|
||||
logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
|
||||
return model_name, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_model_config(cls, tenant_id, llm_type, llm_name=None):
|
||||
from api.db.services.llm_service import LLMService
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
raise LookupError("Tenant not found")
|
||||
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
mdlnm = tenant.embd_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
||||
mdlnm = tenant.asr_id
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
mdlnm = tenant.img2txt_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.RERANK:
|
||||
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.TTS:
|
||||
mdlnm = tenant.tts_id if not llm_name else llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
if not model_config: # for some cases seems fid mismatch
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
if model_config:
|
||||
model_config = model_config.to_dict()
|
||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||
if not llm and fid: # for some cases seems fid mismatch
|
||||
llm = LLMService.query(llm_name=mdlnm)
|
||||
if llm:
|
||||
model_config["is_tools"] = llm[0].is_tools
|
||||
if not model_config:
|
||||
if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
|
||||
llm = LLMService.query(llm_name=mdlnm) if not fid else LLMService.query(llm_name=mdlnm, fid=fid)
|
||||
if llm and llm[0].fid in ["Youdao", "FastEmbed", "BAAI"]:
|
||||
model_config = {"llm_factory": llm[0].fid, "api_key": "", "llm_name": mdlnm, "api_base": ""}
|
||||
if not model_config:
|
||||
if mdlnm == "flag-embedding":
|
||||
model_config = {"llm_factory": "Tongyi-Qianwen", "api_key": "", "llm_name": llm_name, "api_base": ""}
|
||||
else:
|
||||
if not mdlnm:
|
||||
raise LookupError(f"Type of {llm_type} model is not set.")
|
||||
raise LookupError("Model({}) not authorized".format(mdlnm))
|
||||
return model_config
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
kwargs.update({"provider": model_config["llm_factory"]})
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return
|
||||
return EmbeddingModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.RERANK:
|
||||
if model_config["llm_factory"] not in RerankModel:
|
||||
return
|
||||
return RerankModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
|
||||
|
||||
if llm_type == LLMType.IMAGE2TEXT.value:
|
||||
if model_config["llm_factory"] not in CvModel:
|
||||
return
|
||||
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.CHAT.value:
|
||||
if model_config["llm_factory"] not in ChatModel:
|
||||
return
|
||||
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
|
||||
|
||||
if llm_type == LLMType.SPEECH2TEXT:
|
||||
if model_config["llm_factory"] not in Seq2txtModel:
|
||||
return
|
||||
return Seq2txtModel[model_config["llm_factory"]](key=model_config["api_key"], model_name=model_config["llm_name"], lang=lang, base_url=model_config["api_base"])
|
||||
if llm_type == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return
|
||||
return TTSModel[model_config["llm_factory"]](
|
||||
model_config["api_key"],
|
||||
model_config["llm_name"],
|
||||
base_url=model_config["api_base"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
|
||||
e, tenant = TenantService.get_by_id(tenant_id)
|
||||
if not e:
|
||||
logging.error(f"Tenant not found: {tenant_id}")
|
||||
return 0
|
||||
|
||||
llm_map = {
|
||||
LLMType.EMBEDDING.value: tenant.embd_id if not llm_name else llm_name,
|
||||
LLMType.SPEECH2TEXT.value: tenant.asr_id,
|
||||
LLMType.IMAGE2TEXT.value: tenant.img2txt_id,
|
||||
LLMType.CHAT.value: tenant.llm_id if not llm_name else llm_name,
|
||||
LLMType.RERANK.value: tenant.rerank_id if not llm_name else llm_name,
|
||||
LLMType.TTS.value: tenant.tts_id if not llm_name else llm_name,
|
||||
}
|
||||
|
||||
mdlnm = llm_map.get(llm_type)
|
||||
if mdlnm is None:
|
||||
logging.error(f"LLM type error: {llm_type}")
|
||||
return 0
|
||||
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
|
||||
|
||||
try:
|
||||
num = (
|
||||
cls.model.update(used_tokens=cls.model.used_tokens + used_tokens)
|
||||
.where(cls.model.tenant_id == tenant_id, cls.model.llm_name == llm_name, cls.model.llm_factory == llm_factory if llm_factory else True)
|
||||
.execute()
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("TenantLLMService.increase_usage got exception,Failed to update used_tokens for tenant_id=%s, llm_name=%s", tenant_id, llm_name)
|
||||
return 0
|
||||
|
||||
return num
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_openai_models(cls):
|
||||
objs = cls.model.select().where((cls.model.llm_factory == "OpenAI"), ~(cls.model.llm_name == "text-embedding-3-small"), ~(cls.model.llm_name == "text-embedding-3-large")).dicts()
|
||||
return list(objs)
|
||||
|
||||
@staticmethod
|
||||
def llm_id2llm_type(llm_id: str) -> str | None:
|
||||
from api.db.services.llm_service import LLMService
|
||||
llm_id, *_ = TenantLLMService.split_model_name_and_factory(llm_id)
|
||||
llm_factories = settings.FACTORY_LLM_INFOS
|
||||
for llm_factory in llm_factories:
|
||||
for llm in llm_factory["llm"]:
|
||||
if llm_id == llm["llm_name"]:
|
||||
return llm["model_type"].split(",")[-1]
|
||||
|
||||
for llm in LLMService.query(llm_name=llm_id):
|
||||
return llm.model_type
|
||||
|
||||
llm = TenantLLMService.get_or_none(llm_name=llm_id)
|
||||
if llm:
|
||||
return llm.model_type
|
||||
for llm in TenantLLMService.query(llm_name=llm_id):
|
||||
return llm.model_type
|
||||
|
||||
|
||||
class LLM4Tenant:
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_type = llm_type
|
||||
self.llm_name = llm_name
|
||||
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
|
||||
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
|
||||
self.max_length = model_config.get("max_tokens", 8192)
|
||||
|
||||
self.is_tools = model_config.get("is_tools", False)
|
||||
self.verbose_tool_use = kwargs.get("verbose_tool_use")
|
||||
|
||||
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
|
||||
self.langfuse = None
|
||||
if langfuse_keys:
|
||||
langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host)
|
||||
if langfuse.auth_check():
|
||||
self.langfuse = langfuse
|
||||
trace_id = self.langfuse.create_trace_id()
|
||||
self.trace_context = {"trace_id": trace_id}
|
||||
@ -48,7 +48,8 @@ from werkzeug.http import HTTP_STATUS_CODES
|
||||
from api import settings
|
||||
from api.constants import REQUEST_MAX_WAIT_SEC, REQUEST_WAIT_SEC
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils import CustomJSONEncoder, get_uuid, json_dumps
|
||||
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
|
||||
|
||||
|
||||
@ -2632,9 +2632,11 @@ data:{
|
||||
"document_name": "1.txt",
|
||||
"dataset_id": "8e83e57a884611ef9d760242ac120006",
|
||||
"image_id": "",
|
||||
"url": null,
|
||||
"similarity": 0.7,
|
||||
"vector_similarity": 0.0,
|
||||
"term_similarity": 1.0,
|
||||
"doc_type": [],
|
||||
"positions": [
|
||||
""
|
||||
]
|
||||
@ -2649,6 +2651,7 @@ data:{
|
||||
]
|
||||
},
|
||||
"prompt": "xxxxxxxxxxx",
|
||||
"created_at": 1755055623.6401553,
|
||||
"id": "a84c5dd4-97b4-4624-8c3b-974012c8000d",
|
||||
"session_id": "82b0ab2a9c1911ef9d870242ac120006"
|
||||
}
|
||||
@ -2681,7 +2684,7 @@ Creates a session with an agent.
|
||||
- Method: POST
|
||||
- URL: `/api/v1/agents/{agent_id}/sessions?user_id={user_id}`
|
||||
- Headers:
|
||||
- `'content-Type: application/json' or 'multipart/form-data'`
|
||||
- `'content-Type: application/json'
|
||||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||
- Body:
|
||||
- the required parameters:`str`
|
||||
@ -2701,29 +2704,6 @@ curl --request POST \
|
||||
}'
|
||||
```
|
||||
|
||||
If the **Begin** component in your agent takes required parameters:
|
||||
|
||||
```bash
|
||||
curl --request POST \
|
||||
--url http://{address}/api/v1/agents/{agent_id}/sessions \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer <YOUR_API_KEY>' \
|
||||
--data '{
|
||||
"lang":"Japanese",
|
||||
"file":"Who are you"
|
||||
}'
|
||||
```
|
||||
|
||||
If the **Begin** component in your agent takes required file parameters:
|
||||
|
||||
```bash
|
||||
curl --request POST \
|
||||
--url http://{address}/api/v1/agents/{agent_id}/sessions?user_id={user_id} \
|
||||
--header 'Content-Type: multipart/form-data' \
|
||||
--header 'Authorization: Bearer <YOUR_API_KEY>' \
|
||||
--form '<FILE_KEY>=@./test1.png'
|
||||
```
|
||||
|
||||
##### Request parameters
|
||||
|
||||
- `agent_id`: (*Path parameter*)
|
||||
@ -2739,101 +2719,190 @@ Success:
|
||||
{
|
||||
"code": 0,
|
||||
"data": {
|
||||
"agent_id": "b4a39922b76611efaa1a0242ac120006",
|
||||
"agent_id": "dbb4ed366e8611f09690a55a6daec4ef",
|
||||
"dsl": {
|
||||
"answer": [],
|
||||
"components": {
|
||||
"Answer:GreenReadersDrum": {
|
||||
"Message:EightyJobsAsk": {
|
||||
"downstream": [],
|
||||
"obj": {
|
||||
"component_name": "Answer",
|
||||
"inputs": [],
|
||||
"output": null,
|
||||
"params": {}
|
||||
"component_name": "Message",
|
||||
"params": {
|
||||
"content": [
|
||||
"{begin@var1}{begin@var2}"
|
||||
],
|
||||
"debug_inputs": {},
|
||||
"delay_after_error": 2.0,
|
||||
"description": "",
|
||||
"exception_default_value": null,
|
||||
"exception_goto": null,
|
||||
"exception_method": null,
|
||||
"inputs": {},
|
||||
"max_retries": 0,
|
||||
"message_history_window_size": 22,
|
||||
"outputs": {
|
||||
"content": {
|
||||
"type": "str",
|
||||
"value": null
|
||||
}
|
||||
},
|
||||
"stream": true
|
||||
}
|
||||
},
|
||||
"upstream": []
|
||||
"upstream": [
|
||||
"begin"
|
||||
]
|
||||
},
|
||||
"begin": {
|
||||
"downstream": [],
|
||||
"downstream": [
|
||||
"Message:EightyJobsAsk"
|
||||
],
|
||||
"obj": {
|
||||
"component_name": "Begin",
|
||||
"inputs": [],
|
||||
"output": {},
|
||||
"params": {}
|
||||
"params": {
|
||||
"debug_inputs": {},
|
||||
"delay_after_error": 2.0,
|
||||
"description": "",
|
||||
"enablePrologue": true,
|
||||
"enable_tips": true,
|
||||
"exception_default_value": null,
|
||||
"exception_goto": null,
|
||||
"exception_method": null,
|
||||
"inputs": {
|
||||
"var1": {
|
||||
"name": "var1",
|
||||
"optional": false,
|
||||
"options": [],
|
||||
"type": "line",
|
||||
"value": null
|
||||
},
|
||||
"var2": {
|
||||
"name": "var2",
|
||||
"optional": false,
|
||||
"options": [],
|
||||
"type": "line",
|
||||
"value": null
|
||||
}
|
||||
},
|
||||
"max_retries": 0,
|
||||
"message_history_window_size": 22,
|
||||
"mode": "conversational",
|
||||
"outputs": {},
|
||||
"prologue": "Hi! I'm your assistant, what can I do for you?",
|
||||
"tips": "Please fill up the form"
|
||||
}
|
||||
},
|
||||
"upstream": []
|
||||
}
|
||||
},
|
||||
"embed_id": "",
|
||||
"globals": {
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": [],
|
||||
"sys.query": "",
|
||||
"sys.user_id": ""
|
||||
},
|
||||
"graph": {
|
||||
"edges": [],
|
||||
"edges": [
|
||||
{
|
||||
"data": {
|
||||
"isHovered": false
|
||||
},
|
||||
"id": "xy-edge__beginstart-Message:EightyJobsAskend",
|
||||
"markerEnd": "logo",
|
||||
"source": "begin",
|
||||
"sourceHandle": "start",
|
||||
"style": {
|
||||
"stroke": "rgba(151, 154, 171, 1)",
|
||||
"strokeWidth": 1
|
||||
},
|
||||
"target": "Message:EightyJobsAsk",
|
||||
"targetHandle": "end",
|
||||
"type": "buttonEdge",
|
||||
"zIndex": 1001
|
||||
}
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"form": {
|
||||
"enablePrologue": true,
|
||||
"inputs": {
|
||||
"var1": {
|
||||
"name": "var1",
|
||||
"optional": false,
|
||||
"options": [],
|
||||
"type": "line"
|
||||
},
|
||||
"var2": {
|
||||
"name": "var2",
|
||||
"optional": false,
|
||||
"options": [],
|
||||
"type": "line"
|
||||
}
|
||||
},
|
||||
"mode": "conversational",
|
||||
"prologue": "Hi! I'm your assistant, what can I do for you?"
|
||||
},
|
||||
"label": "Begin",
|
||||
"name": "begin"
|
||||
},
|
||||
"dragging": false,
|
||||
"height": 44,
|
||||
"id": "begin",
|
||||
"position": {
|
||||
"x": 53.25688640427177,
|
||||
"y": 198.37155679786412
|
||||
"measured": {
|
||||
"height": 112,
|
||||
"width": 200
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 53.25688640427177,
|
||||
"y": 198.37155679786412
|
||||
"position": {
|
||||
"x": 270.64098070942583,
|
||||
"y": -56.320928437811176
|
||||
},
|
||||
"selected": false,
|
||||
"sourcePosition": "left",
|
||||
"targetPosition": "right",
|
||||
"type": "beginNode",
|
||||
"width": 200
|
||||
"type": "beginNode"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"form": {},
|
||||
"label": "Answer",
|
||||
"name": "dialog_0"
|
||||
"form": {
|
||||
"content": [
|
||||
"{begin@var1}{begin@var2}"
|
||||
]
|
||||
},
|
||||
"label": "Message",
|
||||
"name": "Message_0"
|
||||
},
|
||||
"dragging": false,
|
||||
"height": 44,
|
||||
"id": "Answer:GreenReadersDrum",
|
||||
"id": "Message:EightyJobsAsk",
|
||||
"measured": {
|
||||
"height": 57,
|
||||
"width": 200
|
||||
},
|
||||
"position": {
|
||||
"x": 360.43473114516974,
|
||||
"y": 207.29298425089348
|
||||
"x": 279.5,
|
||||
"y": 190
|
||||
},
|
||||
"positionAbsolute": {
|
||||
"x": 360.43473114516974,
|
||||
"y": 207.29298425089348
|
||||
},
|
||||
"selected": false,
|
||||
"selected": true,
|
||||
"sourcePosition": "right",
|
||||
"targetPosition": "left",
|
||||
"type": "logicNode",
|
||||
"width": 200
|
||||
"type": "messageNode"
|
||||
}
|
||||
]
|
||||
},
|
||||
"history": [],
|
||||
"memory": [],
|
||||
"messages": [],
|
||||
"path": [
|
||||
[
|
||||
"begin"
|
||||
],
|
||||
[]
|
||||
],
|
||||
"reference": []
|
||||
"path": [],
|
||||
"retrieval": [],
|
||||
"task_id": "dbb4ed366e8611f09690a55a6daec4ef"
|
||||
},
|
||||
"id": "2581031eb7a311efb5200242ac120005",
|
||||
"id": "0b02fe80780e11f084adcfdc3ed1d902",
|
||||
"message": [
|
||||
{
|
||||
"content": "Hi! I'm your smart assistant. What can I do for you?",
|
||||
"content": "Hi! I'm your assistant, what can I do for you?",
|
||||
"role": "assistant"
|
||||
}
|
||||
],
|
||||
"source": "agent",
|
||||
"user_id": "69736c5e723611efb51b0242ac120007"
|
||||
"user_id": "c3fb861af27a11efa69751e139332ced"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@ -105,4 +105,5 @@ REMEMBER:
|
||||
- Cite FACTS, not opinions or transitions
|
||||
- Each citation supports the ENTIRE sentence
|
||||
- When in doubt, ask: "Would a fact-checker need to verify this?"
|
||||
- Place citations at sentence end, before punctuation
|
||||
- Place citations at sentence end, before punctuation
|
||||
- Format likes this is FORBIDDEN: [ID:0, ID:5, ID:...]. It MUST be seperated like, [ID:0][ID:5]...
|
||||
|
||||
@ -197,7 +197,7 @@ def question_proposal(chat_mdl, content, topn=3):
|
||||
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
|
||||
if not chat_mdl:
|
||||
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
||||
@ -231,7 +231,7 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_
|
||||
def cross_languages(tenant_id, llm_id, query, languages=[]):
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
|
||||
if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
||||
|
||||
@ -191,7 +191,6 @@ class RAGFlowS3:
|
||||
time.sleep(1)
|
||||
return
|
||||
|
||||
@use_prefix_path
|
||||
@use_default_bucket
|
||||
def rm_bucket(self, bucket, *args, **kwargs):
|
||||
for conn in self.conn:
|
||||
|
||||
46
web/src/components/originui/password-input.tsx
Normal file
46
web/src/components/originui/password-input.tsx
Normal file
@ -0,0 +1,46 @@
|
||||
// https://originui.com/r/comp-23.json
|
||||
|
||||
'use client';
|
||||
|
||||
import { EyeIcon, EyeOffIcon } from 'lucide-react';
|
||||
import React, { useId, useState } from 'react';
|
||||
import { Input, InputProps } from '../ui/input';
|
||||
|
||||
export default React.forwardRef<HTMLInputElement, InputProps>(
|
||||
function PasswordInput({ ...props }, ref) {
|
||||
const id = useId();
|
||||
const [isVisible, setIsVisible] = useState<boolean>(false);
|
||||
|
||||
const toggleVisibility = () => setIsVisible((prevState) => !prevState);
|
||||
|
||||
return (
|
||||
<div className="*:not-first:mt-2">
|
||||
{/* <Label htmlFor={id}>Show/hide password input</Label> */}
|
||||
<div className="relative">
|
||||
<Input
|
||||
id={id}
|
||||
className="pe-9"
|
||||
placeholder="Password"
|
||||
type={isVisible ? 'text' : 'password'}
|
||||
ref={ref}
|
||||
{...props}
|
||||
/>
|
||||
<button
|
||||
className="text-muted-foreground/80 hover:text-foreground focus-visible:border-ring focus-visible:ring-ring/50 absolute inset-y-0 end-0 flex h-full w-9 items-center justify-center rounded-e-md transition-[color,box-shadow] outline-none focus:z-10 focus-visible:ring-[3px] disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50"
|
||||
type="button"
|
||||
onClick={toggleVisibility}
|
||||
aria-label={isVisible ? 'Hide password' : 'Show password'}
|
||||
aria-pressed={isVisible}
|
||||
aria-controls="password"
|
||||
>
|
||||
{isVisible ? (
|
||||
<EyeOffIcon size={16} aria-hidden="true" />
|
||||
) : (
|
||||
<EyeIcon size={16} aria-hidden="true" />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
},
|
||||
);
|
||||
@ -2,7 +2,7 @@ import { PropsWithChildren } from 'react';
|
||||
|
||||
export function PageHeader({ children }: PropsWithChildren) {
|
||||
return (
|
||||
<header className="flex justify-between items-center border-b bg-text-title-invert p-5">
|
||||
<header className="flex justify-between items-center bg-text-title-invert p-5">
|
||||
{children}
|
||||
</header>
|
||||
);
|
||||
|
||||
@ -1,52 +0,0 @@
|
||||
import { Input } from '@/components/originui/input';
|
||||
import { EyeIcon, EyeOffIcon } from 'lucide-react';
|
||||
import { ChangeEvent, forwardRef, useId, useState } from 'react';
|
||||
|
||||
type PropType = {
|
||||
name: string;
|
||||
value: string;
|
||||
onBlur: () => void;
|
||||
onChange: (event: ChangeEvent<HTMLInputElement>) => void;
|
||||
};
|
||||
|
||||
function PasswordInput(props: PropType) {
|
||||
const id = useId();
|
||||
const [isVisible, setIsVisible] = useState<boolean>(false);
|
||||
|
||||
const toggleVisibility = () => setIsVisible((prevState) => !prevState);
|
||||
|
||||
return (
|
||||
<div className="*:not-first:mt-2 w-full">
|
||||
{/* <Label htmlFor={id}>Show/hide password input</Label> */}
|
||||
<div className="relative">
|
||||
<Input
|
||||
autoComplete="off"
|
||||
inputMode="numeric"
|
||||
id={id}
|
||||
className="pe-9"
|
||||
placeholder=""
|
||||
type={isVisible ? 'text' : 'password'}
|
||||
value={props.value}
|
||||
onBlur={props.onBlur}
|
||||
onChange={(ev) => props.onChange(ev)}
|
||||
/>
|
||||
<button
|
||||
className="text-muted-foreground/80 hover:text-foreground focus-visible:border-ring focus-visible:ring-ring/50 absolute inset-y-0 end-0 flex h-full w-9 items-center justify-center rounded-e-md transition-[color,box-shadow] outline-none focus:z-10 focus-visible:ring-[3px] disabled:pointer-events-none disabled:cursor-not-allowed disabled:opacity-50"
|
||||
type="button"
|
||||
onClick={toggleVisibility}
|
||||
aria-label={isVisible ? 'Hide password' : 'Show password'}
|
||||
aria-pressed={isVisible}
|
||||
aria-controls="password"
|
||||
>
|
||||
{isVisible ? (
|
||||
<EyeOffIcon size={16} aria-hidden="true" />
|
||||
) : (
|
||||
<EyeIcon size={16} aria-hidden="true" />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default forwardRef(PasswordInput);
|
||||
51
web/src/components/tavily-form-field.tsx
Normal file
51
web/src/components/tavily-form-field.tsx
Normal file
@ -0,0 +1,51 @@
|
||||
import { useTranslate } from '@/hooks/common-hooks';
|
||||
import { useFormContext } from 'react-hook-form';
|
||||
import PasswordInput from './originui/password-input';
|
||||
import {
|
||||
FormControl,
|
||||
FormDescription,
|
||||
FormField,
|
||||
FormItem,
|
||||
FormLabel,
|
||||
FormMessage,
|
||||
} from './ui/form';
|
||||
|
||||
interface IProps {
|
||||
name?: string;
|
||||
}
|
||||
|
||||
export function TavilyFormField({
|
||||
name = 'prompt_config.tavily_api_key',
|
||||
}: IProps) {
|
||||
const form = useFormContext();
|
||||
const { t } = useTranslate('chat');
|
||||
|
||||
return (
|
||||
<FormField
|
||||
control={form.control}
|
||||
name={name}
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel tooltip={t('tavilyApiKeyTip')}>Tavily API Key</FormLabel>
|
||||
<FormControl>
|
||||
<PasswordInput
|
||||
{...field}
|
||||
placeholder={t('tavilyApiKeyMessage')}
|
||||
autoComplete="new-password"
|
||||
></PasswordInput>
|
||||
</FormControl>
|
||||
<FormDescription>
|
||||
<a
|
||||
href="https://app.tavily.com/home"
|
||||
target={'_blank'}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{t('tavilyApiKeyHelp')}
|
||||
</a>
|
||||
</FormDescription>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@ -3,6 +3,7 @@
|
||||
import { FileUploader } from '@/components/file-uploader';
|
||||
import { KnowledgeBaseFormField } from '@/components/knowledge-base-item';
|
||||
import { SwitchFormField } from '@/components/switch-fom-field';
|
||||
import { TavilyFormField } from '@/components/tavily-form-field';
|
||||
import {
|
||||
FormControl,
|
||||
FormField,
|
||||
@ -105,6 +106,7 @@ export default function ChatBasicSetting() {
|
||||
name={'prompt_config.tts'}
|
||||
label={t('tts')}
|
||||
></SwitchFormField>
|
||||
<TavilyFormField></TavilyFormField>
|
||||
<KnowledgeBaseFormField></KnowledgeBaseFormField>
|
||||
</div>
|
||||
);
|
||||
|
||||
@ -68,8 +68,8 @@ export function ChatSettings({ switchSettingVisible }: ChatSettingsProps) {
|
||||
}, [data, form]);
|
||||
|
||||
return (
|
||||
<section className="p-5 w-[440px] ">
|
||||
<div className="flex justify-between items-center text-base">
|
||||
<section className="p-5 w-[440px] border-l">
|
||||
<div className="flex justify-between items-center text-base pb-2">
|
||||
Chat Settings
|
||||
<X className="size-4 cursor-pointer" onClick={switchSettingVisible} />
|
||||
</div>
|
||||
|
||||
@ -24,6 +24,7 @@ export function useChatSettingSchema() {
|
||||
optional: z.boolean(),
|
||||
}),
|
||||
),
|
||||
tavily_api_key: z.string().optional(),
|
||||
});
|
||||
|
||||
const formSchema = z.object({
|
||||
|
||||
155
web/src/pages/next-chats/chat/chat-box/multiple-chat-box.tsx
Normal file
155
web/src/pages/next-chats/chat/chat-box/multiple-chat-box.tsx
Normal file
@ -0,0 +1,155 @@
|
||||
import { NextMessageInput } from '@/components/message-input/next';
|
||||
import MessageItem from '@/components/message-item';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { MessageType } from '@/constants/chat';
|
||||
import {
|
||||
useFetchConversation,
|
||||
useFetchDialog,
|
||||
useGetChatSearchParams,
|
||||
} from '@/hooks/use-chat-request';
|
||||
import { useFetchUserInfo } from '@/hooks/user-setting-hooks';
|
||||
import { buildMessageUuidWithRole } from '@/utils/chat';
|
||||
import { Trash2 } from 'lucide-react';
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
useGetSendButtonDisabled,
|
||||
useSendButtonDisabled,
|
||||
} from '../../hooks/use-button-disabled';
|
||||
import { useCreateConversationBeforeUploadDocument } from '../../hooks/use-create-conversation';
|
||||
import { useSendMessage } from '../../hooks/use-send-chat-message';
|
||||
import { buildMessageItemReference } from '../../utils';
|
||||
import { useAddChatBox } from '../use-add-box';
|
||||
|
||||
type MultipleChatBoxProps = {
|
||||
controller: AbortController;
|
||||
chatBoxIds: string[];
|
||||
} & Pick<ReturnType<typeof useAddChatBox>, 'removeChatBox'>;
|
||||
|
||||
type ChatCardProps = { id: string } & Pick<
|
||||
MultipleChatBoxProps,
|
||||
'controller' | 'removeChatBox'
|
||||
>;
|
||||
|
||||
function ChatCard({ controller, removeChatBox, id }: ChatCardProps) {
|
||||
const {
|
||||
value,
|
||||
// scrollRef,
|
||||
messageContainerRef,
|
||||
sendLoading,
|
||||
derivedMessages,
|
||||
handleInputChange,
|
||||
handlePressEnter,
|
||||
regenerateMessage,
|
||||
removeMessageById,
|
||||
stopOutputMessage,
|
||||
} = useSendMessage(controller);
|
||||
|
||||
const { data: userInfo } = useFetchUserInfo();
|
||||
const { data: currentDialog } = useFetchDialog();
|
||||
const { data: conversation } = useFetchConversation();
|
||||
|
||||
const handleRemoveChatBox = useCallback(() => {
|
||||
removeChatBox(id);
|
||||
}, [id, removeChatBox]);
|
||||
|
||||
return (
|
||||
<Card className="bg-transparent border flex-1">
|
||||
<CardHeader className="border-b px-5 py-3">
|
||||
<CardTitle className="flex justify-between items-center">
|
||||
<div>
|
||||
<span className="text-base">Card Title</span>
|
||||
<Button variant={'ghost'} className="ml-2">
|
||||
GPT-4
|
||||
</Button>
|
||||
</div>
|
||||
<Button variant={'ghost'} onClick={handleRemoveChatBox}>
|
||||
<Trash2 />
|
||||
</Button>
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<div ref={messageContainerRef} className="flex-1 overflow-auto min-h-0">
|
||||
<div className="w-full">
|
||||
{derivedMessages?.map((message, i) => {
|
||||
return (
|
||||
<MessageItem
|
||||
loading={
|
||||
message.role === MessageType.Assistant &&
|
||||
sendLoading &&
|
||||
derivedMessages.length - 1 === i
|
||||
}
|
||||
key={buildMessageUuidWithRole(message)}
|
||||
item={message}
|
||||
nickname={userInfo.nickname}
|
||||
avatar={userInfo.avatar}
|
||||
avatarDialog={currentDialog.icon}
|
||||
reference={buildMessageItemReference(
|
||||
{
|
||||
message: derivedMessages,
|
||||
reference: conversation.reference,
|
||||
},
|
||||
message,
|
||||
)}
|
||||
// clickDocumentButton={clickDocumentButton}
|
||||
index={i}
|
||||
removeMessageById={removeMessageById}
|
||||
regenerateMessage={regenerateMessage}
|
||||
sendLoading={sendLoading}
|
||||
></MessageItem>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
{/* <div ref={scrollRef} /> */}
|
||||
</div>
|
||||
</CardContent>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
export function MultipleChatBox({
|
||||
controller,
|
||||
chatBoxIds,
|
||||
removeChatBox,
|
||||
}: MultipleChatBoxProps) {
|
||||
const {
|
||||
value,
|
||||
sendLoading,
|
||||
handleInputChange,
|
||||
handlePressEnter,
|
||||
stopOutputMessage,
|
||||
} = useSendMessage(controller);
|
||||
|
||||
const { createConversationBeforeUploadDocument } =
|
||||
useCreateConversationBeforeUploadDocument();
|
||||
const { conversationId } = useGetChatSearchParams();
|
||||
const disabled = useGetSendButtonDisabled();
|
||||
const sendDisabled = useSendButtonDisabled(value);
|
||||
return (
|
||||
<section className="h-full flex flex-col">
|
||||
<div className="flex gap-4 flex-1 px-5 pb-12">
|
||||
{chatBoxIds.map((id) => (
|
||||
<ChatCard
|
||||
key={id}
|
||||
controller={controller}
|
||||
id={id}
|
||||
removeChatBox={removeChatBox}
|
||||
></ChatCard>
|
||||
))}
|
||||
</div>
|
||||
<NextMessageInput
|
||||
disabled={disabled}
|
||||
sendDisabled={sendDisabled}
|
||||
sendLoading={sendLoading}
|
||||
value={value}
|
||||
onInputChange={handleInputChange}
|
||||
onPressEnter={handlePressEnter}
|
||||
conversationId={conversationId}
|
||||
createConversationBeforeUploadDocument={
|
||||
createConversationBeforeUploadDocument
|
||||
}
|
||||
stopOutputMessage={stopOutputMessage}
|
||||
/>
|
||||
</section>
|
||||
);
|
||||
}
|
||||
@ -11,16 +11,16 @@ import { buildMessageUuidWithRole } from '@/utils/chat';
|
||||
import {
|
||||
useGetSendButtonDisabled,
|
||||
useSendButtonDisabled,
|
||||
} from '../hooks/use-button-disabled';
|
||||
import { useCreateConversationBeforeUploadDocument } from '../hooks/use-create-conversation';
|
||||
import { useSendMessage } from '../hooks/use-send-chat-message';
|
||||
import { buildMessageItemReference } from '../utils';
|
||||
} from '../../hooks/use-button-disabled';
|
||||
import { useCreateConversationBeforeUploadDocument } from '../../hooks/use-create-conversation';
|
||||
import { useSendMessage } from '../../hooks/use-send-chat-message';
|
||||
import { buildMessageItemReference } from '../../utils';
|
||||
|
||||
interface IProps {
|
||||
controller: AbortController;
|
||||
}
|
||||
|
||||
export function ChatBox({ controller }: IProps) {
|
||||
export function SingleChatBox({ controller }: IProps) {
|
||||
const {
|
||||
value,
|
||||
// scrollRef,
|
||||
@ -43,7 +43,7 @@ export function ChatBox({ controller }: IProps) {
|
||||
const sendDisabled = useSendButtonDisabled(value);
|
||||
|
||||
return (
|
||||
<section className="border-x flex flex-col p-5 flex-1 min-w-0">
|
||||
<section className="flex flex-col p-5 h-full">
|
||||
<div ref={messageContainerRef} className="flex-1 overflow-auto min-h-0">
|
||||
<div className="w-full">
|
||||
{derivedMessages?.map((message, i) => {
|
||||
@ -7,14 +7,20 @@ import {
|
||||
BreadcrumbPage,
|
||||
BreadcrumbSeparator,
|
||||
} from '@/components/ui/breadcrumb';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { useSetModalState } from '@/hooks/common-hooks';
|
||||
import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks';
|
||||
import { useFetchDialog } from '@/hooks/use-chat-request';
|
||||
import { cn } from '@/lib/utils';
|
||||
import { Plus } from 'lucide-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useHandleClickConversationCard } from '../hooks/use-click-card';
|
||||
import { ChatSettings } from './app-settings/chat-settings';
|
||||
import { ChatBox } from './chat-box';
|
||||
import { MultipleChatBox } from './chat-box/multiple-chat-box';
|
||||
import { SingleChatBox } from './chat-box/single-chat-box';
|
||||
import { Sessions } from './sessions';
|
||||
import { useAddChatBox } from './use-add-box';
|
||||
|
||||
export default function Chat() {
|
||||
const { navigateToChatList } = useNavigatePage();
|
||||
@ -24,9 +30,16 @@ export default function Chat() {
|
||||
useHandleClickConversationCard();
|
||||
const { visible: settingVisible, switchVisible: switchSettingVisible } =
|
||||
useSetModalState(true);
|
||||
const {
|
||||
removeChatBox,
|
||||
addChatBox,
|
||||
chatBoxIds,
|
||||
hasSingleChatBox,
|
||||
hasThreeChatBox,
|
||||
} = useAddChatBox();
|
||||
|
||||
return (
|
||||
<section className="h-full flex flex-col">
|
||||
<section className="h-full flex flex-col pr-5">
|
||||
<PageHeader>
|
||||
<Breadcrumb>
|
||||
<BreadcrumbList>
|
||||
@ -43,18 +56,52 @@ export default function Chat() {
|
||||
</Breadcrumb>
|
||||
</PageHeader>
|
||||
<div className="flex flex-1 min-h-0">
|
||||
<div className="flex flex-1 min-w-0">
|
||||
<Sessions
|
||||
handleConversationCardClick={handleConversationCardClick}
|
||||
switchSettingVisible={switchSettingVisible}
|
||||
></Sessions>
|
||||
<ChatBox controller={controller}></ChatBox>
|
||||
</div>
|
||||
{settingVisible && (
|
||||
<ChatSettings
|
||||
switchSettingVisible={switchSettingVisible}
|
||||
></ChatSettings>
|
||||
)}
|
||||
<Sessions
|
||||
handleConversationCardClick={handleConversationCardClick}
|
||||
switchSettingVisible={switchSettingVisible}
|
||||
></Sessions>
|
||||
|
||||
<Card className="flex-1 min-w-0 bg-transparent border h-full">
|
||||
<CardContent className="flex p-0 h-full">
|
||||
<Card className="flex flex-col flex-1 bg-transparent">
|
||||
<CardHeader
|
||||
className={cn('p-5', { 'border-b': hasSingleChatBox })}
|
||||
>
|
||||
<CardTitle className="flex justify-between items-center">
|
||||
<div className="text-base">
|
||||
Card Title
|
||||
<Button variant={'ghost'} className="ml-2">
|
||||
GPT-4
|
||||
</Button>
|
||||
</div>
|
||||
<Button
|
||||
variant={'ghost'}
|
||||
onClick={addChatBox}
|
||||
disabled={hasThreeChatBox}
|
||||
>
|
||||
<Plus></Plus> Multiple Models
|
||||
</Button>
|
||||
</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent className="flex-1 p-0">
|
||||
{hasSingleChatBox ? (
|
||||
<SingleChatBox controller={controller}></SingleChatBox>
|
||||
) : (
|
||||
<MultipleChatBox
|
||||
chatBoxIds={chatBoxIds}
|
||||
controller={controller}
|
||||
removeChatBox={removeChatBox}
|
||||
></MultipleChatBox>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
{settingVisible && (
|
||||
<ChatSettings
|
||||
switchSettingVisible={switchSettingVisible}
|
||||
></ChatSettings>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
</section>
|
||||
);
|
||||
|
||||
26
web/src/pages/next-chats/chat/use-add-box.ts
Normal file
26
web/src/pages/next-chats/chat/use-add-box.ts
Normal file
@ -0,0 +1,26 @@
|
||||
import { useCallback, useState } from 'react';
|
||||
import { v4 as uuid } from 'uuid';
|
||||
|
||||
export function useAddChatBox() {
|
||||
const [ids, setIds] = useState<string[]>([uuid()]);
|
||||
|
||||
const hasSingleChatBox = ids.length === 1;
|
||||
|
||||
const hasThreeChatBox = ids.length === 3;
|
||||
|
||||
const addChatBox = useCallback(() => {
|
||||
setIds((prev) => [...prev, uuid()]);
|
||||
}, []);
|
||||
|
||||
const removeChatBox = useCallback((id: string) => {
|
||||
setIds((prev) => prev.filter((x) => x !== id));
|
||||
}, []);
|
||||
|
||||
return {
|
||||
chatBoxIds: ids,
|
||||
hasSingleChatBox,
|
||||
hasThreeChatBox,
|
||||
addChatBox,
|
||||
removeChatBox,
|
||||
};
|
||||
}
|
||||
@ -1,4 +1,4 @@
|
||||
import PasswordInput from '@/components/password-input';
|
||||
import PasswordInput from '@/components/originui/password-input';
|
||||
import { Avatar, AvatarFallback, AvatarImage } from '@/components/ui/avatar';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import {
|
||||
|
||||
Reference in New Issue
Block a user