Add resume parser and fix bugs (#59)

* Update .gitignore

* Update .gitignore

* Add resume parser and fix bugs
This commit is contained in:
KevinHuSh
2024-02-07 19:27:23 +08:00
committed by GitHub
parent eb8254e688
commit c5ea37cd30
16 changed files with 451 additions and 57 deletions

View File

@ -47,17 +47,20 @@ def list():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(retmsg="Document not found!")
query = {
"doc_ids": [doc_id], "page": page, "size": size, "question": question
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id))
res = {"total": sres.total, "chunks": []}
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
"chunk_id": id,
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"],
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id].get("content_with_weight", ""),
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
@ -110,7 +113,7 @@ def get():
"important_kwd")
def set():
req = request.json
d = {"id": req["chunk_id"]}
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
d["content_ltks"] = huqie.qie(req["content_with_weight"])
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req["important_kwd"]
@ -181,11 +184,12 @@ def create():
md5 = hashlib.md5()
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
chunck_id = md5.hexdigest()
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])}
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]}
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", [])
d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])

View File

@ -13,16 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from flask import request
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, LLMBundle
from api.settings import access_logger
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.search import index_name
from rag.utils import num_tokens_from_string, encoder
@ -163,6 +168,17 @@ def chat(dialog, messages, **kwargs):
if not llm:
raise LookupError("LLM(%s) not found"%dialog.llm_id)
llm = llm[0]
question = messages[-1]["content"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
## try to use sql if field mapping is good to go
if field_map:
markdown_tbl,chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
if markdown_tbl:
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
prompt_config = dialog.prompt_config
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":continue
@ -170,9 +186,6 @@ def chat(dialog, messages, **kwargs):
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
question = messages[-1]["content"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
@ -196,4 +209,46 @@ def chat(dialog, messages, **kwargs):
vtweight=dialog.vector_similarity_weight)
for c in kbinfos["chunks"]:
if c.get("vector"):del c["vector"]
return {"answer": answer, "retrieval": kbinfos}
return {"answer": answer, "retrieval": kbinfos}
def use_sql(question,field_map, tenant_id, chat_mdl):
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构根据我的问题写出sql。"
user_promt = """
表名:{}
数据库表字段说明如下:
{}
问题:{}
请写出SQL。
""".format(
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k,v in field_map.items()]),
question
)
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
sql = re.sub(r" +", " ", sql)
if sql[:len("select ")].lower() != "select ":
return None, None
if sql[:len("select *")].lower() != "select *":
sql = "select doc_id,docnm_kwd," + sql[6:]
tbl = retrievaler.sql_retrieval(sql)
if not tbl: return None, None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]
if not docid_idx or not docnm_idx:
access_logger.error("SQL missing field: " + sql)
return "\n".join([clmns, line, "\n".join(rows)]), []
rows = "\n".join([r+f"##{ii}$$" for ii,r in enumerate(rows)])
docid_idx = list(docid_idx)[0]
docnm_idx = list(docnm_idx)[0]
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]

View File

@ -21,9 +21,6 @@ import flask
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user
from api.db.db_models import Task
from api.db.services.task_service import TaskService
from rag.nlp import search
from rag.utils import ELASTICSEARCH
from api.db.services import duplicate_name
@ -35,7 +32,7 @@ from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
from api.utils.file_utils import filename_type
from api.utils.file_utils import filename_type, thumbnail
@manager.route('/upload', methods=['POST'])
@ -78,7 +75,8 @@ def upload():
"type": filename_type(filename),
"name": filename,
"location": location,
"size": len(blob)
"size": len(blob),
"thumbnail": thumbnail(filename, blob)
})
return get_json_result(data=doc.to_json())
except Exception as e: