add keyword extraction in graph (#1373)

### What problem does this PR solve?
#918 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
KevinHuSh
2024-07-04 15:57:25 +08:00
committed by GitHub
parent acd78c5ef2
commit 258c9ea644
11 changed files with 506 additions and 88 deletions

View File

@ -142,16 +142,19 @@ def run():
@manager.route('/reset', methods=['POST'])
@validate_request("canvas_id")
@validate_request("id")
@login_required
def reset():
req = request.json
try:
user_canvas = UserCanvasService.get_by_id(req["canvas_id"])
canvas = Canvas(user_canvas.dsl, current_user.id)
e, user_canvas = UserCanvasService.get_by_id(req["id"])
if not e:
return server_error_response("canvas not found.")
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
canvas.reset()
req["dsl"] = json.loads(str(canvas))
UserCanvasService.update_by_id(req["canvas_id"], dsl=req["dsl"])
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
return get_json_result(data=req["dsl"])
except Exception as e:
return server_error_response(e)

View File

@ -156,7 +156,7 @@ factory_infos = [{
"tags": "TEXT EMBEDDING, TEXT RE-RANK",
"status": "1",
},{
"name": "Minimax",
"name": "MiniMax",
"logo": "",
"tags": "LLM,TEXT EMBEDDING",
"status": "1",

View File

@ -102,19 +102,26 @@ class Canvas(ABC):
self.load()
def load(self):
assert self.dsl.get("components", {}).get("begin"), "There have to be a 'Begin' component."
self.components = self.dsl["components"]
cpn_nms = set([])
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
assert "Answer" in cpn_nms, "There have to be an 'Answer' component."
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
param.update(cpn["obj"]["params"])
param.check()
cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
if cpn["obj"].component_name == "Categorize":
for _,desc in param.category_description.items():
for _, desc in param.category_description.items():
if desc["to"] not in cpn["downstream"]:
cpn["downstream"].append(desc["to"])
self.path = self.dsl["path"]
self.history = self.dsl["history"]
self.messages = self.dsl["messages"]
@ -140,7 +147,8 @@ class Canvas(ABC):
self.messages = []
self.answer = []
self.reference = []
self.components = {}
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
self._embed_id = ""
def run(self, **kwargs):
@ -176,7 +184,7 @@ class Canvas(ABC):
ran += 1
prepare2run(self.components[self.path[-2][-1]]["downstream"])
while ran < len(self.path[-1]):
while 0 <= ran < len(self.path[-1]):
if DEBUG: print(ran, self.path)
cpn_id = self.path[-1][ran]
cpn = self.get_component(cpn_id)

View File

@ -418,6 +418,9 @@ class ComponentBase(ABC):
o = pd.DataFrame(o)
return self._param.output_var_name, o
def reset(self):
setattr(self._param, self._param.output_var_name, None)
def set_output(self, v: pd.DataFrame):
setattr(self._param, self._param.output_var_name, v)

View File

@ -0,0 +1,68 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from graph.component import GenerateParam, Generate
from graph.settings import DEBUG
class KeywordExtractParam(GenerateParam):
"""
Define the KeywordExtract component parameters.
"""
def __init__(self):
super().__init__()
self.temperature = 0.5
self.prompt = ""
self.topn = 1
def check(self):
super().check()
def get_prompt(self):
self.prompt = """
- Role: You're a question analyzer.
- Requirements:
- Summarize user's question, and give top %s important keyword/phrase.
- Use comma as a delimiter to separate keywords/phrases.
- Answer format: (in language of user's question)
- keyword:
"""%self.topn
return self.prompt
class KeywordExtract(Generate, ABC):
component_name = "RewriteQuestion"
def _run(self, history, **kwargs):
q = ""
for r, c in self._canvas.history[::-1]:
if r == "user":
q += c
break
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": q}],
self._param.gen_conf())
ans = re.sub(r".*keyword:", "", ans).strip()
if DEBUG: print(ans, ":::::::::::::::::::::::::::::::::")
return KeywordExtract.be_output(ans)

View File

@ -1,5 +1,5 @@
{
"id": 0,
"id": 1,
"title": "HR call-out assistant(Chinese)",
"description": "A HR call-out assistant. It will introduce the given job, answer the candidates' question about this job. And the most important thing is that it will try to obtain the contact information of the candidates. What you need to do is to link a knowledgebase which contains job description in 'Retrieval' component.",
"canvas_type": "chatbot",

View File

@ -1,5 +1,5 @@
{
"id": 1,
"id": 2,
"title": "Customer service",
"description": "A call-in customer service chat bot. It will provide useful information about the products, answer customers' questions and soothe the customers' bad emotions.",
"canvas_type": "chatbot",
@ -106,7 +106,7 @@
"upstream": ["categorize:0"]
},
"generate:complain": {
"downstream": [],
"downstream": ["answer:0"],
"obj": {
"component_name": "Generate",
"params": {
@ -116,7 +116,7 @@
"prompt": "You are a customer support. the Customers complain even curse about the products but not specific enough. You need to ask him/her what's the specific problem with the product. Be nice, patient and concern to soothe your customers emotions at first place."
}
},
"upstream": ["categorize:0", "answer:0"]
"upstream": ["categorize:0"]
},
"message:get_contact": {
"downstream": ["answer:0"],
@ -286,13 +286,13 @@
{
"id": "reactflow__edge-answer:0a-generate:complaind",
"markerEnd": "logo",
"source": "answer:0",
"source": "generate:complain",
"sourceHandle": "a",
"style": {
"stroke": "rgb(202 197 245)",
"strokeWidth": 2
},
"target": "generate:complain",
"target": "answer:0",
"targetHandle": "d",
"type": "buttonEdge"
},

File diff suppressed because one or more lines are too long

View File

@ -1,5 +1,5 @@
{
"id": 2,
"id": 3,
"title": "Interpreter",
"description": "An interpreter. Type the content you want to translate and the object language like: Hi there => Spanish. Hava a try!",
"canvas_type": "chatbot",

View File

@ -1,79 +1,79 @@
{
"components": {
"begin": {
"obj": {
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0", "rewrite:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "rewrite:0"
}
},
"downstream": ["generate:0", "rewrite:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"rewrite:0": {
"obj": {
"component_name": "RewriteQuestion",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.8
}
},
"downstream": ["retrieval:0"],
"upstream": ["relevant:0"]
}
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0", "switch:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0", "rewrite:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "rewrite:0"
}
},
"downstream": ["generate:0", "rewrite:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"rewrite:0": {
"obj":{
"component_name": "RewriteQuestion",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.8
}
},
"downstream": ["retrieval:0"],
"upstream": ["relevant:0"]
}
},
"history": [],
"messages": [],
"path": [],
"reference": [],
"answer": []
}
}

View File

@ -95,6 +95,7 @@ class DeepSeekChat(Base):
if not base_url: base_url="https://api.deepseek.com/v1"
super().__init__(key, model_name, base_url)
class AzureChat(Base):
def __init__(self, key, model_name, **kwargs):
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")