mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Add resume parser and fix bugs (#59)
* Update .gitignore * Update .gitignore * Add resume parser and fix bugs
This commit is contained in:
@ -47,17 +47,20 @@ def list():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(retmsg="Tenant not found!")
|
||||
e, doc = DocumentService.get_by_id(doc_id)
|
||||
if not e:
|
||||
return get_data_error_result(retmsg="Document not found!")
|
||||
query = {
|
||||
"doc_ids": [doc_id], "page": page, "size": size, "question": question
|
||||
}
|
||||
if "available_int" in req:
|
||||
query["available_int"] = int(req["available_int"])
|
||||
sres = retrievaler.search(query, search.index_name(tenant_id))
|
||||
res = {"total": sres.total, "chunks": []}
|
||||
res = {"total": sres.total, "chunks": [], "doc": doc.to_dict()}
|
||||
for id in sres.ids:
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id]["content_with_weight"],
|
||||
"content_with_weight": rmSpace(sres.highlight[id]) if question else sres.field[id].get("content_with_weight", ""),
|
||||
"doc_id": sres.field[id]["doc_id"],
|
||||
"docnm_kwd": sres.field[id]["docnm_kwd"],
|
||||
"important_kwd": sres.field[id].get("important_kwd", []),
|
||||
@ -110,7 +113,7 @@ def get():
|
||||
"important_kwd")
|
||||
def set():
|
||||
req = request.json
|
||||
d = {"id": req["chunk_id"]}
|
||||
d = {"id": req["chunk_id"], "content_with_weight": req["content_with_weight"]}
|
||||
d["content_ltks"] = huqie.qie(req["content_with_weight"])
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
d["important_kwd"] = req["important_kwd"]
|
||||
@ -181,11 +184,12 @@ def create():
|
||||
md5 = hashlib.md5()
|
||||
md5.update((req["content_with_weight"] + req["doc_id"]).encode("utf-8"))
|
||||
chunck_id = md5.hexdigest()
|
||||
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"])}
|
||||
d = {"id": chunck_id, "content_ltks": huqie.qie(req["content_with_weight"]), "content_with_weight": req["content_with_weight"]}
|
||||
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
|
||||
d["important_kwd"] = req.get("important_kwd", [])
|
||||
d["important_tks"] = huqie.qie(" ".join(req.get("important_kwd", [])))
|
||||
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
|
||||
|
||||
try:
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
|
||||
@ -13,16 +13,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from api.db.services.dialog_service import DialogService, ConversationService
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle
|
||||
from api.settings import access_logger
|
||||
from api.utils.api_utils import server_error_response, get_data_error_result, validate_request
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.llm import ChatModel
|
||||
from rag.nlp import retrievaler
|
||||
from rag.nlp.search import index_name
|
||||
from rag.utils import num_tokens_from_string, encoder
|
||||
|
||||
|
||||
@ -163,6 +168,17 @@ def chat(dialog, messages, **kwargs):
|
||||
if not llm:
|
||||
raise LookupError("LLM(%s) not found"%dialog.llm_id)
|
||||
llm = llm[0]
|
||||
question = messages[-1]["content"]
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
## try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
markdown_tbl,chunks = use_sql(question, field_map, dialog.tenant_id, chat_mdl)
|
||||
if markdown_tbl:
|
||||
return {"answer": markdown_tbl, "retrieval": {"chunks": chunks}}
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge":continue
|
||||
@ -170,9 +186,6 @@ def chat(dialog, messages, **kwargs):
|
||||
if p["key"] not in kwargs:
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}"%p["key"], " ")
|
||||
|
||||
question = messages[-1]["content"]
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
kbinfos = retrievaler.retrieval(question, embd_mdl, dialog.tenant_id, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight, top=1024, aggs=False)
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
@ -196,4 +209,46 @@ def chat(dialog, messages, **kwargs):
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
for c in kbinfos["chunks"]:
|
||||
if c.get("vector"):del c["vector"]
|
||||
return {"answer": answer, "retrieval": kbinfos}
|
||||
return {"answer": answer, "retrieval": kbinfos}
|
||||
|
||||
|
||||
def use_sql(question,field_map, tenant_id, chat_mdl):
|
||||
sys_prompt = "你是一个DBA。你需要这对以下表的字段结构,根据我的问题写出sql。"
|
||||
user_promt = """
|
||||
表名:{};
|
||||
数据库表字段说明如下:
|
||||
{}
|
||||
|
||||
问题:{}
|
||||
请写出SQL。
|
||||
""".format(
|
||||
index_name(tenant_id),
|
||||
"\n".join([f"{k}: {v}" for k,v in field_map.items()]),
|
||||
question
|
||||
)
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.1})
|
||||
sql = re.sub(r".*?select ", "select ", sql, flags=re.IGNORECASE)
|
||||
sql = re.sub(r" +", " ", sql)
|
||||
if sql[:len("select ")].lower() != "select ":
|
||||
return None, None
|
||||
if sql[:len("select *")].lower() != "select *":
|
||||
sql = "select doc_id,docnm_kwd," + sql[6:]
|
||||
|
||||
tbl = retrievaler.sql_retrieval(sql)
|
||||
if not tbl: return None, None
|
||||
|
||||
docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"])
|
||||
docnm_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"])
|
||||
clmn_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx|docnm_idx)]
|
||||
|
||||
clmns = "|".join([re.sub(r"/.*", "", field_map.get(tbl["columns"][i]["name"], f"C{i}")) for i in clmn_idx]) + "|原文"
|
||||
line = "|".join(["------" for _ in range(len(clmn_idx))]) + "|------"
|
||||
rows = ["|".join([str(r[i]) for i in clmn_idx])+"|" for r in tbl["rows"]]
|
||||
if not docid_idx or not docnm_idx:
|
||||
access_logger.error("SQL missing field: " + sql)
|
||||
return "\n".join([clmns, line, "\n".join(rows)]), []
|
||||
|
||||
rows = "\n".join([r+f"##{ii}$$" for ii,r in enumerate(rows)])
|
||||
docid_idx = list(docid_idx)[0]
|
||||
docnm_idx = list(docnm_idx)[0]
|
||||
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
|
||||
|
||||
@ -21,9 +21,6 @@ import flask
|
||||
from elasticsearch_dsl import Q
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
|
||||
from api.db.db_models import Task
|
||||
from api.db.services.task_service import TaskService
|
||||
from rag.nlp import search
|
||||
from rag.utils import ELASTICSEARCH
|
||||
from api.db.services import duplicate_name
|
||||
@ -35,7 +32,7 @@ from api.db.services.document_service import DocumentService
|
||||
from api.settings import RetCode
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.utils.minio_conn import MINIO
|
||||
from api.utils.file_utils import filename_type
|
||||
from api.utils.file_utils import filename_type, thumbnail
|
||||
|
||||
|
||||
@manager.route('/upload', methods=['POST'])
|
||||
@ -78,7 +75,8 @@ def upload():
|
||||
"type": filename_type(filename),
|
||||
"name": filename,
|
||||
"location": location,
|
||||
"size": len(blob)
|
||||
"size": len(blob),
|
||||
"thumbnail": thumbnail(filename, blob)
|
||||
})
|
||||
return get_json_result(data=doc.to_json())
|
||||
except Exception as e:
|
||||
|
||||
@ -474,7 +474,7 @@ class Knowledgebase(DataBaseModel):
|
||||
vector_similarity_weight = FloatField(default=0.3)
|
||||
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID", default=ParserType.GENERAL.value)
|
||||
parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
|
||||
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
|
||||
status = CharField(max_length=1, null=True, help_text="is it validate(0: wasted,1: validate)", default="1")
|
||||
|
||||
def __str__(self):
|
||||
@ -489,7 +489,7 @@ class Document(DataBaseModel):
|
||||
thumbnail = TextField(null=True, help_text="thumbnail base64 string")
|
||||
kb_id = CharField(max_length=256, null=False, index=True)
|
||||
parser_id = CharField(max_length=32, null=False, help_text="default parser ID")
|
||||
parser_config = JSONField(null=False, default={"from_page":0, "to_page": 100000})
|
||||
parser_config = JSONField(null=False, default={"pages":[[0,1000000]]})
|
||||
source_type = CharField(max_length=128, null=False, default="local", help_text="where dose this document from")
|
||||
type = CharField(max_length=32, null=False, help_text="file extension")
|
||||
created_by = CharField(max_length=32, null=False, help_text="who created it")
|
||||
|
||||
@ -21,5 +21,6 @@ class DialogService(CommonService):
|
||||
model = Dialog
|
||||
|
||||
|
||||
|
||||
class ConversationService(CommonService):
|
||||
model = Conversation
|
||||
|
||||
@ -63,3 +63,31 @@ class KnowledgebaseService(CommonService):
|
||||
d = kbs[0].to_dict()
|
||||
d["embd_id"] = kbs[0].tenant.embd_id
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_parser_config(cls, id, config):
|
||||
e, m = cls.get_by_id(id)
|
||||
if not e:raise LookupError(f"knowledgebase({id}) not found.")
|
||||
def dfs_update(old, new):
|
||||
for k,v in new.items():
|
||||
if k not in old:
|
||||
old[k] = v
|
||||
continue
|
||||
if isinstance(v, dict):
|
||||
assert isinstance(old[k], dict)
|
||||
dfs_update(old[k], v)
|
||||
else: old[k] = v
|
||||
dfs_update(m.parser_config, config)
|
||||
cls.update_by_id(id, m.parser_config)
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_field_map(cls, ids):
|
||||
conf = {}
|
||||
for k in cls.get_by_ids(ids):
|
||||
if k.parser_config and "field_map" in k.parser_config:
|
||||
conf.update(k.parser_config)
|
||||
return conf
|
||||
|
||||
|
||||
@ -13,11 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import fitz
|
||||
from PIL import Image
|
||||
from cachetools import LRUCache, cached
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
@ -150,4 +153,33 @@ def filename_type(filename):
|
||||
return FileType.AURAL.value
|
||||
|
||||
if re.match(r".*\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico|mpg|mpeg|avi|rm|rmvb|mov|wmv|asf|dat|asx|wvx|mpe|mpa|mp4)$", filename):
|
||||
return FileType.VISUAL
|
||||
return FileType.VISUAL
|
||||
|
||||
|
||||
def thumbnail(filename, blob):
|
||||
filename = filename.lower()
|
||||
if re.match(r".*\.pdf$", filename):
|
||||
pdf = fitz.open(stream=blob, filetype="pdf")
|
||||
pix = pdf[0].get_pixmap(matrix=fitz.Matrix(0.03, 0.03))
|
||||
buffered = BytesIO()
|
||||
Image.frombytes("RGB", [pix.width, pix.height],
|
||||
pix.samples).save(buffered, format="png")
|
||||
return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
|
||||
|
||||
if re.match(r".*\.(jpg|jpeg|png|tif|gif|icon|ico|webp)$", filename):
|
||||
return ("data:image/%s;base64,"%filename.split(".")[-1]) + base64.b64encode(Image.open(BytesIO(blob)).thumbnail((30, 30)).tobytes())
|
||||
|
||||
if re.match(r".*\.(ppt|pptx)$", filename):
|
||||
import aspose.slides as slides
|
||||
import aspose.pydrawing as drawing
|
||||
try:
|
||||
with slides.Presentation(BytesIO(blob)) as presentation:
|
||||
buffered = BytesIO()
|
||||
presentation.slides[0].get_thumbnail(0.03, 0.03).save(buffered, drawing.imaging.ImageFormat.png)
|
||||
return "data:image/png;base64," + base64.b64encode(buffered.getvalue())
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user