Add resume parser and fix bugs (#59)

* Update .gitignore

* Update .gitignore

* Add resume parser and fix bugs
This commit is contained in:
KevinHuSh
2024-02-07 19:27:23 +08:00
committed by GitHub
parent eb8254e688
commit c5ea37cd30
16 changed files with 451 additions and 57 deletions

View File

@ -47,17 +47,20 @@ def list():
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
if not tenant_id:
return get_data_error_result(retmsg="Tenant not found!")
e, doc = DocumentService.get_by_id(doc_id)
if not e:
return get_data_error_result(retmsg="Document not found!")
query = {
"doc_ids": [doc_id], "page": page, "size": size, "question": question
}
if "available_int" in req:
query["available_int"] = int(req["available_int"])
sres = retrievaler.search(query, search.index_name(tenant_id))
res = {"total": sres.total, "chunks": []}
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
for id in sres.ids:
d = {
"chunk_id": id,
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"],
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id].get("content_with_weight", ""),
"doc_id": sres.field[id]["doc_id"],
"docnm_kwd": sres.field[id]["docnm_kwd"],
"important_kwd": sres.field[id].get("important_kwd", []),
@ -110,7 +113,7 @@ def get():
"important_kwd")
def set():
req = request.json
d = {"id": req["chunk_id"]}
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
d["content_ltks"] = huqie.qie(req["content_with_weight"])
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req["important_kwd"]
@ -181,11 +184,12 @@ def create():
md5 = hashlib.md5()
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
chunck_id = md5.hexdigest()
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])}
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]}
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
d["important_kwd"] = req.get("important_kwd", [])
d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
try:
e, doc = DocumentService.get_by_id(req["doc_id"])

View File

@ -13,16 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
from flask import request
from flask_login import login_required
from api.db.services.dialog_service import DialogService, ConversationService
from api.db import LLMType
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, LLMBundle
from api.settings import access_logger
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
from api.utils import get_uuid
from api.utils.api_utils import get_json_result
from rag.llm import ChatModel
from rag.nlp import retrievaler
from rag.nlp.search import index_name
from rag.utils import num_tokens_from_string, encoder
@ -163,6 +168,17 @@ def chat(dialog, messages, **kwargs):
if not llm:
raise LookupError("LLM(%s) not found"%dialog.llm_id)
llm = llm[0]
question = messages[-1]["content"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
## try to use sql if field mapping is good to go
if field_map:
markdown_tbl,chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
if markdown_tbl:
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
prompt_config = dialog.prompt_config
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":continue
@ -170,9 +186,6 @@ def chat(dialog, messages, **kwargs):
if p["key"] not in kwargs:
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
question = messages[-1]["content"]
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
dialog.vector_similarity_weight, top=1024, aggs=False)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
@ -196,4 +209,46 @@ def chat(dialog, messages, **kwargs):
vtweight=dialog.vector_similarity_weight)
for c in kbinfos["chunks"]:
if c.get("vector"):del c["vector"]
return {"answer": answer, "retrieval": kbinfos}
return {"answer": answer, "retrieval": kbinfos}
def use_sql(question,field_map, tenant_id, chat_mdl):
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构根据我的问题写出sql。"
user_promt = """
表名:{}
数据库表字段说明如下:
{}
问题:{}
请写出SQL。
""".format(
index_name(tenant_id),
"\n".join([f"{k}: {v}" for k,v in field_map.items()]),
question
)
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
sql = re.sub(r" +", " ", sql)
if sql[:len("select ")].lower() != "select ":
return None, None
if sql[:len("select *")].lower() != "select *":
sql = "select doc_id,docnm_kwd," + sql[6:]
tbl = retrievaler.sql_retrieval(sql)
if not tbl: return None, None
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]
if not docid_idx or not docnm_idx:
access_logger.error("SQL missing field: " + sql)
return "\n".join([clmns, line, "\n".join(rows)]), []
rows = "\n".join([r+f"##{ii}$$" for ii,r in enumerate(rows)])
docid_idx = list(docid_idx)[0]
docnm_idx = list(docnm_idx)[0]
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]

View File

@ -21,9 +21,6 @@ import flask
from elasticsearch_dsl import Q
from flask import request
from flask_login import login_required, current_user
from api.db.db_models import Task
from api.db.services.task_service import TaskService
from rag.nlp import search
from rag.utils import ELASTICSEARCH
from api.db.services import duplicate_name
@ -35,7 +32,7 @@ from api.db.services.document_service import DocumentService
from api.settings import RetCode
from api.utils.api_utils import get_json_result
from rag.utils.minio_conn import MINIO
from api.utils.file_utils import filename_type
from api.utils.file_utils import filename_type, thumbnail
@manager.route('/upload', methods=['POST'])
@ -78,7 +75,8 @@ def upload():
"type": filename_type(filename),
"name": filename,
"location": location,
"size": len(blob)
"size": len(blob),
"thumbnail": thumbnail(filename, blob)
})
return get_json_result(data=doc.to_json())
except Exception as e:

View File

@ -474,7 +474,7 @@ class Knowledgebase(DataBaseModel):
vector_similarity_weight = FloatField(default=0.3)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted1: validate)", default="1")
def __str__(self):
@ -489,7 +489,7 @@ class Document(DataBaseModel):
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
kb_id = CharField(max_length=256, null=False, index=True)
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
type = CharField(max_length=32, null=False, help_text="file extension")
created_by = CharField(max_length=32, null=False, help_text="who created it")

View File

@ -21,5 +21,6 @@ class DialogService(CommonService):
model = Dialog
class ConversationService(CommonService):
model = Conversation

View File

@ -63,3 +63,31 @@ class KnowledgebaseService(CommonService):
d = kbs[0].to_dict()
d["embd_id"] = kbs[0].tenant.embd_id
return d
@classmethod
@DB.connection_context()
def update_parser_config(cls, id, config):
e, m = cls.get_by_id(id)
if not e:raise LookupError(f"knowledgebase({id}) not found.")
def dfs_update(old, new):
for k,v in new.items():
if k not in old:
old[k] = v
continue
if isinstance(v, dict):
assert isinstance(old[k], dict)
dfs_update(old[k], v)
else: old[k] = v
dfs_update(m.parser_config, config)
cls.update_by_id(id, m.parser_config)
@classmethod
@DB.connection_context()
def get_field_map(cls, ids):
conf = {}
for k in cls.get_by_ids(ids):
if k.parser_config and "field_map" in k.parser_config:
conf.update(k.parser_config)
return conf

View File

@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import json
import os
import re
from io import BytesIO
import fitz
from PIL import Image
from cachetools import LRUCache, cached
from ruamel.yaml import YAML
@ -150,4 +153,33 @@ def filename_type(filename):
return FileType.AURAL.value
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
return FileType.VISUAL
return FileType.VISUAL
def thumbnail(filename, blob):
filename = filename.lower()
if re.match(r".*\.pdf$", filename):
pdf = fitz.open(stream=blob, filetype="pdf")
pix = pdf[0].get_pixmap(matrix=fitz.Matrix(0.03, 0.03))
buffered = BytesIO()
Image.frombytes("RGB", [pix.width, pix.height],
pix.samples).save(buffered, format="png")
return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
return ("data:image/%s;base64,"%filename.split(".")[-1]) + base64.b64encode(Image.open(BytesIO(blob)).thumbnail((30, 30)).tobytes())
if re.match(r".*\.(ppt|pptx)$", filename):
import aspose.slides as slides
import aspose.pydrawing as drawing
try:
with slides.Presentation(BytesIO(blob)) as presentation:
buffered = BytesIO()
presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
except Exception as e:
pass