Chat Use CVmodel (#1607)

### What problem does this PR solve?

#1230 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
H
2024-07-19 18:36:34 +08:00
committed by GitHub
parent 347cb61f26
commit 58df013722
4 changed files with 325 additions and 6 deletions

View File

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import json
import re
from copy import deepcopy
@ -26,6 +28,7 @@ from rag.app.resume import forbidden_select_fields4resume
from rag.nlp import keyword_extraction
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
class DialogService(CommonService):
@ -73,6 +76,15 @@ def message_fit_in(msg, max_length=4000):
return max_length, msg
def llm_id2llm_type(llm_id):
fnm = os.path.join(get_project_base_directory(), "conf")
llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
for llm_factory in llm_factories["factory_llm_infos"]:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
return llm["model_type"].strip(",")[-1]
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
llm = LLMService.query(llm_name=dialog.llm_id)
@ -91,7 +103,10 @@ def chat(dialog, messages, stream=True, **kwargs):
questions = [m["content"] for m in messages if m["role"] == "user"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
if llm_id2llm_type(dialog.llm_id) == "image2text":
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
else:
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
@ -328,7 +343,10 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
def relevant(tenant_id, llm_id, question, contents: list):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
if llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are a grader assessing relevance of a retrieved document to a user question.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
@ -347,7 +365,10 @@ def relevant(tenant_id, llm_id, question, contents: list):
def rewrite(tenant_id, llm_id, question):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
if llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
prompt = """
You are an expert at query expansion to generate a paraphrasing of a question.
I can't retrieval relevant information from the knowledge base by using user's question directly.