init README of deepdoc, add picture processor. (#71)

* init README of deepdoc, add picture processor.

* add resume parsing
KevinHuSh authored on 2024-02-23 18:28:12 +08:00, committed by GitHub
parent d32322c081
commit 7fd1eca582
42 changed files with 58319 additions and 350 deletions

View File

@@ -12,7 +12,7 @@
#
import copy
import re
from deepdoc.parser import bullets_category, is_english, tokenize, remove_contents_table, \
from rag.nlp import bullets_category, is_english, tokenize, remove_contents_table, \
hierarchical_merge, make_colon_as_title, naive_merge, random_choices
from rag.nlp import huqie
from deepdoc.parser import PdfParser, DocxParser
@@ -47,7 +47,7 @@ class Pdf(PdfParser):
return [(b["text"] + self._line_tag(b, zoomin), b.get("layoutno","")) for b in self.boxes], tbls
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, txt.
Since a book is long and not all the parts are useful, if it's a PDF,
@@ -94,7 +94,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
sections = [t for t, _ in sections]
# is it English
eng = is_english(random_choices(sections, k=218))
eng = lang.lower() == "english"#is_english(random_choices(sections, k=218))
res = []
# add tables
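For reference, a minimal sketch of calling the updated chunker with the new lang switch; the file name, page range, and callback are hypothetical, and default parser settings are assumed. The same lang.lower() == "english" check replaces the is_english() sampling heuristic in the other chunkers changed below.

from rag.app import book

def progress(prog=None, msg=""):
    # hypothetical progress sink; the pipeline normally wires this to set_progress()
    print(prog, msg)

# English handling is now declared by the caller instead of guessed from sampled text.
chunks = book.chunk("sample_book.pdf", from_page=0, to_page=12,
                    lang="English", callback=progress)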

View File

@@ -14,7 +14,7 @@ import copy
import re
from io import BytesIO
from docx import Document
from deepdoc.parser import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
from rag.nlp import bullets_category, is_english, tokenize, remove_contents_table, hierarchical_merge, \
make_colon_as_title
from rag.nlp import huqie
from deepdoc.parser import PdfParser, DocxParser
@@ -68,7 +68,7 @@ class Pdf(PdfParser):
return [b["text"] + self._line_tag(b, zoomin) for b in self.boxes]
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, txt.
"""
@@ -106,7 +106,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
else: raise NotImplementedError("file type not supported yet (docx, pdf, txt supported)")
# is it English
eng = is_english(sections)
eng = lang.lower() == "english"#is_english(sections)
# Remove 'Contents' part
remove_contents_table(sections, eng)

View File

@@ -1,7 +1,6 @@
import copy
import re
from deepdoc.parser import tokenize
from rag.nlp import huqie
from rag.nlp import huqie, tokenize
from deepdoc.parser import PdfParser
from rag.utils import num_tokens_from_string
@@ -57,7 +56,7 @@ class Pdf(PdfParser):
return [b["text"] + self._line_tag(b, zoomin) for b in self.boxes], tbls
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Only pdf is supported.
"""
@@ -74,7 +73,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
doc["title_tks"] = huqie.qie(re.sub(r"\.[a-zA-Z]+$", "", doc["docnm_kwd"]))
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
# is it English
eng = pdf_parser.is_english
eng = lang.lower() == "english"#pdf_parser.is_english
res = []
# add tables

View File

@@ -13,8 +13,7 @@
import copy
import re
from rag.app import laws
from deepdoc.parser import is_english, tokenize, naive_merge
from rag.nlp import huqie
from rag.nlp import huqie, is_english, tokenize, naive_merge
from deepdoc.parser import PdfParser
from rag.settings import cron_logger
@@ -38,7 +37,7 @@ class Pdf(PdfParser):
return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes]
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Supported file formats are docx, pdf, txt.
This method applies the naive way to chunk files.
@@ -80,7 +79,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
parser_config = kwargs.get("parser_config", {"chunk_token_num": 128, "delimiter": "\n!?。;!?"})
cks = naive_merge(sections, parser_config["chunk_token_num"], parser_config["delimiter"])
eng = is_english(cks)
eng = lang.lower() == "english"#is_english(cks)
res = []
# wrap up to es documents
for ck in cks:

View File

@@ -15,8 +15,7 @@ import re
from collections import Counter
from api.db import ParserType
from deepdoc.parser import tokenize
from rag.nlp import huqie
from rag.nlp import huqie, tokenize
from deepdoc.parser import PdfParser
import numpy as np
from rag.utils import num_tokens_from_string
@@ -140,7 +139,7 @@ class Pdf(PdfParser):
}
def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **kwargs):
def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Only pdf is supported.
The abstract of the paper is kept as a single chunk and will not be split further.
@@ -156,7 +155,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, callback=None, **k
doc["title_sm_tks"] = huqie.qieqie(doc["title_tks"])
doc["authors_sm_tks"] = huqie.qieqie(doc["authors_tks"])
# is it English
eng = pdf_parser.is_english
eng = lang.lower() == "english"#pdf_parser.is_english
print("It's English.....", eng)
res = []

rag/app/picture.py (new file, 56 lines)
View File

@@ -0,0 +1,56 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import numpy as np
from PIL import Image
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from rag.nlp import tokenize
from deepdoc.vision import OCR
ocr = OCR()
def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
try:
cv_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, lang=lang)
except Exception as e:
callback(prog=-1, msg=str(e))
return []
img = Image.open(io.BytesIO(binary))
doc = {
"docnm_kwd": filename,
"image": img
}
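# Run OCR first; if it recognizes enough text, that text alone is indexed.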
bxs = ocr(np.array(img))
txt = "\n".join([t[0] for _, t in bxs if t[0]])
eng = lang.lower() == "english"
callback(0.4, "Finish OCR: (%s ...)" % txt[:12])
if (eng and len(txt.split(" ")) > 32) or len(txt) > 32:
tokenize(doc, txt, eng)
callback(0.8, "OCR results is too long to use CV LLM.")
return [doc]
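# OCR found little text, so fall back to the image-to-text (CV) model for a description.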
try:
callback(0.4, "Use CV LLM to describe the picture.")
ans = cv_mdl.describe(binary)
callback(0.8, "CV LLM respoond: %s ..." % ans[:32])
txt += "\n" + ans
tokenize(doc, txt, eng)
return [doc]
except Exception as e:
callback(prog=-1, msg=str(e))
return []
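A minimal sketch of invoking this new picture chunker directly; the file name and tenant id are placeholders, and a tenant with a configured IMAGE2TEXT model is assumed.

from rag.app import picture

def progress(prog=None, msg=""):
    print(prog, msg)

with open("photo.jpg", "rb") as f:
    docs = picture.chunk("photo.jpg", f.read(), tenant_id="<tenant-id>",
                         lang="English", callback=progress)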

View File

@@ -13,46 +13,14 @@
import copy
import re
from io import BytesIO
from pptx import Presentation
from deepdoc.parser import tokenize, is_english
from rag.nlp import tokenize, is_english
from rag.nlp import huqie
from deepdoc.parser import PdfParser
from deepdoc.parser import PdfParser, PptParser
class Ppt(object):
def __init__(self):
super().__init__()
def __extract(self, shape):
if shape.shape_type == 19:
tb = shape.table
rows = []
for i in range(1, len(tb.rows)):
rows.append("; ".join([tb.cell(0, j).text + ": " + tb.cell(i, j).text for j in range(len(tb.columns)) if tb.cell(i, j)]))
return "\n".join(rows)
if shape.has_text_frame:
return shape.text_frame.text
if shape.shape_type == 6:
texts = []
for p in shape.shapes:
t = self.__extract(p)
if t: texts.append(t)
return "\n".join(texts)
class Ppt(PptParser):
def __call__(self, fnm, from_page, to_page, callback=None):
ppt = Presentation(fnm) if isinstance(
fnm, str) else Presentation(
BytesIO(fnm))
txts = []
self.total_page = len(ppt.slides)
for i, slide in enumerate(ppt.slides[from_page: to_page]):
texts = []
for shape in slide.shapes:
txt = self.__extract(shape)
if txt: texts.append(txt)
txts.append("\n".join(texts))
txts = super().__call__(fnm, from_page, to_page)
callback(0.5, "Text extraction finished.")
import aspose.slides as slides

View File

@@ -14,7 +14,7 @@ import re
from io import BytesIO
from nltk import word_tokenize
from openpyxl import load_workbook
from deepdoc.parser import is_english, random_choices
from rag.nlp import is_english, random_choices
from rag.nlp import huqie, stemmer
from deepdoc.parser import ExcelParser
@@ -81,7 +81,7 @@ def beAdoc(d, q, a, eng):
return d
def chunk(filename, binary=None, callback=None, **kwargs):
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
If the file is in Excel format, it should contain two columns, question and answer, with no header row.
@@ -113,7 +113,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
break
txt += l
lines = txt.split("\n")
eng = is_english([rmPrefix(l) for l in lines[:100]])
eng = lang.lower() == "english"#is_english([rmPrefix(l) for l in lines[:100]])
fails = []
for i, line in enumerate(lines):
arr = [l for l in line.split("\t") if len(l) > 1]

View File

@@ -20,8 +20,7 @@ from openpyxl import load_workbook
from dateutil.parser import parse as datetime_parse
from api.db.services.knowledgebase_service import KnowledgebaseService
from deepdoc.parser import is_english, tokenize
from rag.nlp import huqie
from rag.nlp import huqie, is_english, tokenize
from deepdoc.parser import ExcelParser
@@ -112,7 +111,7 @@ def column_data_type(arr):
return arr, ty
def chunk(filename, binary=None, callback=None, **kwargs):
def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
For csv or txt files, the delimiter between columns is TAB.
@@ -192,7 +191,7 @@ def chunk(filename, binary=None, callback=None, **kwargs):
clmns_map = [(py_clmns[j] + fieds_map[clmn_tys[j]], clmns[j])
for j in range(len(clmns))]
eng = is_english(txts)
eng = lang.lower() == "english"#is_english(txts)
for ii, row in df.iterrows():
d = {}
row_txt = []

View File

@@ -13,12 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
from abc import ABC
from PIL import Image
from openai import OpenAI
import os
import base64
from io import BytesIO
from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory
class Base(ABC):
def __init__(self, key, model_name):
@@ -44,25 +50,26 @@ class Base(ABC):
{
"role": "user",
"content": [
{
"type": "text",
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等。",
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64}"
},
},
{
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
},
],
}
]
class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview"):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese"):
self.client = OpenAI(api_key=key)
self.model_name = model_name
self.lang = lang
def describe(self, image, max_tokens=300):
b64 = self.image2base64(image)
@@ -76,18 +83,40 @@ class GptV4(Base):
class QWenCV(Base):
def __init__(self, key, model_name="qwen-vl-chat-v1"):
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese"):
import dashscope
dashscope.api_key = key
self.model_name = model_name
self.lang = lang
def prompt(self, binary):
# Write the image to a temp file so it can be passed to DashScope as a local file:// URI.
tmp_dir = get_project_base_directory("tmp")
if not os.path.exists(tmp_dir): os.mkdir(tmp_dir)
path = os.path.join(tmp_dir, "%s.jpg"%get_uuid())
Image.open(io.BytesIO(binary)).save(path)
return [
{
"role": "user",
"content": [
{
"image": f"file://{path}"
},
{
"text": "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out.",
},
],
}
]
def describe(self, image, max_tokens=300):
from http import HTTPStatus
from dashscope import MultiModalConversation
response = MultiModalConversation.call(model=self.model_name,
messages=self.prompt(self.image2base64(image)))
messages=self.prompt(image))
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'], response.usage.output_tokens
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
return response.message, 0
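A hedged sketch of constructing one of these vision models with the new language hint; the API key and image file are placeholders, and QWenCV.describe() is shown above returning a (text, token_count) pair.

# placeholder key; requires the dashscope SDK to be installed and reachable
cv_mdl = QWenCV(key="<dashscope-api-key>", lang="English")
with open("photo.jpg", "rb") as f:
    description, used_tokens = cv_mdl.describe(f.read())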
@@ -95,9 +124,10 @@ from zhipuai import ZhipuAI
class Zhipu4V(Base):
def __init__(self, key, model_name="glm-4v"):
def __init__(self, key, model_name="glm-4v", lang="Chinese"):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
self.lang = lang
def describe(self, image, max_tokens=1024):
b64 = self.image2base64(image)

View File

@@ -5,3 +5,219 @@ retrievaler = search.Dealer(ELASTICSEARCH)
from nltk.stem import PorterStemmer
stemmer = PorterStemmer()
import re
from nltk import word_tokenize
from . import huqie
from rag.utils import num_tokens_from_string
import random
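# Heading-numbering pattern families: Chinese statute style (编/章/节/条), Arabic-numbered
# outlines, mixed Chinese outlines, and English PART/Chapter/Section headings.
# bullets_category() below picks the family that best matches a document's sections.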
BULLET_PATTERN = [[
r"第[零一二三四五六七八九十百0-9]+(分?编|部分)",
r"第[零一二三四五六七八九十百0-9]+章",
r"第[零一二三四五六七八九十百0-9]+节",
r"第[零一二三四五六七八九十百0-9]+条",
r"[\(][零一二三四五六七八九十百]+[\)]",
], [
r"第[0-9]+章",
r"第[0-9]+节",
r"[0-9]{,3}[\. 、]",
r"[0-9]{,2}\.[0-9]{,2}",
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
r"[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}\.[0-9]{,2}",
], [
r"第[零一二三四五六七八九十百0-9]+章",
r"第[零一二三四五六七八九十百0-9]+节",
r"[零一二三四五六七八九十百]+[ 、]",
r"[\(][零一二三四五六七八九十百]+[\)]",
r"[\(][0-9]{,2}[\)]",
], [
r"PART (ONE|TWO|THREE|FOUR|FIVE|SIX|SEVEN|EIGHT|NINE|TEN)",
r"Chapter (I+V?|VI*|XI|IX|X)",
r"Section [0-9]+",
r"Article [0-9]+"
]
]
def random_choices(arr, k):
k = min(len(arr), k)
return random.choices(arr, k=k)
def bullets_category(sections):
global BULLET_PATTERN
hits = [0] * len(BULLET_PATTERN)
for i, pro in enumerate(BULLET_PATTERN):
for sec in sections:
for p in pro:
if re.match(p, sec):
hits[i] += 1
break
maxium = 0
res = -1
for i, h in enumerate(hits):
if h <= maxium: continue
res = i
maxium = h
return res
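# Illustration: the family with the most regex hits wins, -1 if nothing matches, e.g.
#   bullets_category(["第一章 总则", "第一条 目的", "第二条 适用范围"])  -> 0 (Chinese statute style)
#   bullets_category(["PART ONE", "Chapter I", "Section 1"])              -> 3 (English headings)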
def is_english(texts):
eng = 0
for t in texts:
if re.match(r"[a-zA-Z]{2,}", t.strip()):
eng += 1
if eng / len(texts) > 0.8:
return True
return False
def tokenize(d, t, eng):
d["content_with_weight"] = t
if eng:
t = re.sub(r"([a-z])-([a-z])", r"\1\2", t)
d["content_ltks"] = " ".join([stemmer.stem(w) for w in word_tokenize(t)])
else:
d["content_ltks"] = huqie.qie(t)
d["content_sm_ltks"] = huqie.qieqie(d["content_ltks"])
def remove_contents_table(sections, eng=False):
i = 0
while i < len(sections):
def get(i):
nonlocal sections
return (sections[i] if type(sections[i]) == type("") else sections[i][0]).strip()
if not re.match(r"(contents|目录|目次|table of contents|致谢|acknowledge)$",
re.sub(r"( | |\u3000)+", "", get(i).split("@@")[0], re.IGNORECASE)):
i += 1
continue
sections.pop(i)
if i >= len(sections): break
prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2])
while not prefix:
sections.pop(i)
if i >= len(sections): break
prefix = get(i)[:3] if not eng else " ".join(get(i).split(" ")[:2])
sections.pop(i)
if i >= len(sections) or not prefix: break
for j in range(i, min(i + 128, len(sections))):
if not re.match(prefix, get(j)):
continue
for _ in range(i, j): sections.pop(i)
break
def make_colon_as_title(sections):
if not sections: return []
if type(sections[0]) == type(""): return sections
i = 0
while i < len(sections):
txt, layout = sections[i]
i += 1
txt = txt.split("@")[0].strip()
if not txt:
continue
if txt[-1] not in ":":
continue
txt = txt[::-1]
arr = re.split(r"([。?!!?;]| .)", txt)
if len(arr) < 2 or len(arr[1]) < 32:
continue
sections.insert(i - 1, (arr[0][::-1], "title"))
i += 1
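# hierarchical_merge: bucket section indices by heading level (the regexes of BULLET_PATTERN[bull]
# plus layout "title"/"head" hints), then attach to each entry the nearest preceding heading from
# every coarser level, so each returned chunk is a heading chain followed by its body text.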
def hierarchical_merge(bull, sections, depth):
if not sections or bull < 0: return []
if type(sections[0]) == type(""): sections = [(s, "") for s in sections]
sections = [(t,o) for t, o in sections if t and len(t.split("@")[0].strip()) > 1 and not re.match(r"[0-9]+$", t.split("@")[0].strip())]
bullets_size = len(BULLET_PATTERN[bull])
levels = [[] for _ in range(bullets_size + 2)]
def not_title(txt):
if re.match(r"第[零一二三四五六七八九十百0-9]+条", txt): return False
if len(txt) >= 128: return True
return re.search(r"[,;,。;!!]", txt)
for i, (txt, layout) in enumerate(sections):
for j, p in enumerate(BULLET_PATTERN[bull]):
if re.match(p, txt.strip()) and not not_title(txt):
levels[j].append(i)
break
else:
if re.search(r"(title|head)", layout):
levels[bullets_size].append(i)
else:
levels[bullets_size + 1].append(i)
sections = [t for t, _ in sections]
for s in sections: print("--", s)
def binary_search(arr, target):
if not arr: return -1
if target > arr[-1]: return len(arr) - 1
if target < arr[0]: return -1
s, e = 0, len(arr)
while e - s > 1:
i = (e + s) // 2
if target > arr[i]:
s = i
continue
elif target < arr[i]:
e = i
continue
else:
assert False
return s
cks = []
readed = [False] * len(sections)
levels = levels[::-1]
for i, arr in enumerate(levels[:depth]):
for j in arr:
if readed[j]: continue
readed[j] = True
cks.append([j])
if i + 1 == len(levels) - 1: continue
for ii in range(i + 1, len(levels)):
jj = binary_search(levels[ii], j)
if jj < 0: continue
if jj > cks[-1][-1]: cks[-1].pop(-1)
cks[-1].append(levels[ii][jj])
for ii in cks[-1]: readed[ii] = True
for i in range(len(cks)):
cks[i] = [sections[j] for j in cks[i][::-1]]
print("--------------\n", "\n* ".join(cks[i]))
return cks
def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
if not sections: return []
if type(sections[0]) == type(""): sections = [(s, "") for s in sections]
cks = [""]
tk_nums = [0]
def add_chunk(t, pos):
nonlocal cks, tk_nums, delimiter
tnum = num_tokens_from_string(t)
if tnum < 8: pos = ""
if tk_nums[-1] > chunk_token_num:
cks.append(t + pos)
tk_nums.append(tnum)
else:
cks[-1] += t + pos
tk_nums[-1] += tnum
for sec, pos in sections:
s, e = 0, 1
while e < len(sec):
if sec[e] in delimiter:
add_chunk(sec[s: e+1], pos)
s = e + 1
e = s + 1
else:
e += 1
if s < e: add_chunk(sec[s: e], pos)
return cks
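# Illustration: sections are split on the delimiter characters and greedily packed into chunks
# of roughly chunk_token_num tokens; position tags are only appended to pieces of 8+ tokens, e.g.
#   naive_merge(["First sentence. Second sentence.", "Third one."], chunk_token_num=128, delimiter=".")
#   -> a single chunk, since everything stays well under 128 tokens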

View File

@@ -21,6 +21,7 @@ import hashlib
import copy
import re
import sys
import traceback
from functools import partial
from timeit import default_timer as timer
@@ -36,7 +37,7 @@ from rag.nlp import search
from io import BytesIO
import pandas as pd
from rag.app import laws, paper, presentation, manual, qa, table, book, resume
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
@@ -56,47 +57,31 @@ FACTORY = {
ParserType.QA.value: qa,
ParserType.TABLE.value: table,
ParserType.RESUME.value: resume,
ParserType.PICTURE.value: picture,
}
def set_progress(task_id, from_page=0, to_page=-1, prog=None, msg="Processing..."):
def set_progress(task_id, from_page=0, to_page=-1,
prog=None, msg="Processing..."):
if prog is not None and prog < 0:
msg = "[ERROR]"+msg
cancel = TaskService.do_cancel(task_id)
if cancel:
msg += " [Canceled]"
prog = -1
if to_page > 0: msg = f"Page({from_page}~{to_page}): " + msg
if to_page > 0:
msg = f"Page({from_page}~{to_page}): " + msg
d = {"progress_msg": msg}
if prog is not None: d["progress"] = prog
if prog is not None:
d["progress"] = prog
try:
TaskService.update_progress(task_id, d)
except Exception as e:
cron_logger.error("set_progress:({}), {}".format(task_id, str(e)))
if cancel:sys.exit()
"""
def chuck_doc(name, binary, tenant_id, cvmdl=None):
suff = os.path.split(name)[-1].lower().split(".")[-1]
if suff.find("pdf") >= 0:
return PDF(binary)
if suff.find("doc") >= 0:
return DOC(binary)
if re.match(r"(xlsx|xlsm|xltx|xltm)", suff):
return EXC(binary)
if suff.find("ppt") >= 0:
return PPT(binary)
if cvmdl and re.search(r"\.(jpg|jpeg|png|tif|gif|pcx|tga|exif|fpx|svg|psd|cdr|pcd|dxf|ufo|eps|ai|raw|WMF|webp|avif|apng|icon|ico)$",
name.lower()):
txt = cvmdl.describe(binary)
field = TextChunker.Fields()
field.text_chunks = [(txt, binary)]
field.table_chunks = []
return field
return TextChunker()(binary)
"""
if cancel:
sys.exit()
def collect(comm, mod, tm):
@@ -109,29 +94,38 @@ def collect(comm, mod, tm):
return tasks
def build(row, cvmdl):
def build(row):
if row["size"] > DOC_MAXIMUM_SIZE:
set_progress(row["id"], prog=-1, msg="File size exceeds( <= %dMb )" %
(int(DOC_MAXIMUM_SIZE / 1024 / 1024)))
return []
callback = partial(set_progress, row["id"], row["from_page"], row["to_page"])
callback = partial(
set_progress,
row["id"],
row["from_page"],
row["to_page"])
chunker = FACTORY[row["parser_id"].lower()]
try:
cron_logger.info("Chunkking {}/{}".format(row["location"], row["name"]))
cks = chunker.chunk(row["name"], binary = MINIO.get(row["kb_id"], row["location"]), from_page=row["from_page"], to_page=row["to_page"],
callback = callback, kb_id=row["kb_id"], parser_config=row["parser_config"])
cron_logger.info(
"Chunkking {}/{}".format(row["location"], row["name"]))
cks = chunker.chunk(row["name"], binary=MINIO.get(row["kb_id"], row["location"]), from_page=row["from_page"],
to_page=row["to_page"], lang=row["language"], callback=callback,
kb_id=row["kb_id"], parser_config=row["parser_config"], tenant_id=row["tenant_id"])
except Exception as e:
if re.search("(No such file|not found)", str(e)):
callback(-1, "Can not find file <%s>" % row["doc_name"])
else:
callback(-1, f"Internal server error: %s" % str(e).replace("'", ""))
callback(-1, f"Internal server error: %s" %
str(e).replace("'", ""))
traceback.print_exc()
cron_logger.warn("Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
cron_logger.warn(
"Chunkking {}/{}: {}".format(row["location"], row["name"], str(e)))
return
callback(msg="Finished slicing files. Start to embedding the content.")
callback(msg="Finished slicing files(%d). Start to embedding the content."%len(cks))
docs = []
doc = {
@@ -142,7 +136,8 @@ def build(row, cvmdl):
d = copy.deepcopy(doc)
d.update(ck)
md5 = hashlib.md5()
md5.update((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8"))
md5.update((ck["content_with_weight"] +
str(d["doc_id"])).encode("utf-8"))
d["_id"] = md5.hexdigest()
d["create_time"] = str(datetime.datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.datetime.now().timestamp()
@@ -173,7 +168,8 @@ def init_kb(row):
def embedding(docs, mdl, parser_config={}):
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [d["content_with_weight"] for d in docs]
tts, cnts = [rmSpace(d["title_tks"]) for d in docs if d.get("title_tks")], [
d["content_with_weight"] for d in docs]
tk_count = 0
if len(tts) == len(cnts):
tts, c = mdl.encode(tts)
@@ -182,7 +178,8 @@ def embedding(docs, mdl, parser_config={}):
cnts, c = mdl.encode(cnts)
tk_count += c
title_w = float(parser_config.get("filename_embd_weight", 0.1))
vects = (title_w * tts + (1-title_w) * cnts) if len(tts) == len(cnts) else cnts
vects = (title_w * tts + (1 - title_w) *
cnts) if len(tts) == len(cnts) else cnts
assert len(vects) == len(docs)
for i, d in enumerate(docs):
@@ -192,7 +189,10 @@ def embedding(docs, mdl, parser_config={}):
def main(comm, mod):
tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"{comm}-{mod}.tm")
tm_fnm = os.path.join(
get_project_base_directory(),
"rag/res",
f"{comm}-{mod}.tm")
tm = findMaxTm(tm_fnm)
rows = collect(comm, mod, tm)
if len(rows) == 0:
@@ -203,15 +203,13 @@ def main(comm, mod):
callback = partial(set_progress, r["id"], r["from_page"], r["to_page"])
try:
embd_mdl = LLMBundle(r["tenant_id"], LLMType.EMBEDDING)
cv_mdl = LLMBundle(r["tenant_id"], LLMType.IMAGE2TEXT)
# TODO: sequence2text model
except Exception as e:
callback(prog=-1, msg=str(e))
continue
st_tm = timer()
cks = build(r, cv_mdl)
if cks is None:continue
cks = build(r)
if cks is None:
continue
if not cks:
tmf.write(str(r["update_time"]) + "\n")
callback(1., "No chunk! Done!")
@@ -233,11 +231,15 @@ def main(comm, mod):
cron_logger.error(str(es_r))
else:
if TaskService.do_cancel(r["id"]):
ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
ELASTICSEARCH.deleteByQuery(
Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
continue
callback(1., "Done!")
DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))
DocumentService.increment_chunk_num(
r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
cron_logger.info(
"Chunk doc({}), token({}), chunks({})".format(
r["id"], tk_count, len(cks)))
tmf.write(str(r["update_time"]) + "\n")
tmf.close()