Feat: add vision LLM PDF parser (#6173)

### What problem does this PR solve?

Add a vision-LLM-based PDF parser. A new `VisionParser` renders each PDF page to an image and asks an IMAGE2TEXT model to transcribe it into clean Markdown, driven by a new `vision_llm_describe_prompt`. The CV model wrappers gain a matching `describe_with_prompt` method (with token-usage accounting in `LLMBundle`), and the naive chunker now routes PDFs to `VisionParser` whenever `layout_recognize` names a vision model instead of `DeepDOC` or `Plain Text`.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
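
For context, a rough usage sketch, not part of the diff: setting `layout_recognize` to anything other than `DeepDOC` or `Plain Text` routes PDFs through the new `VisionParser`. The tenant id and model name below are placeholders.

```python
from rag.app import naive

# Any value other than "DeepDOC" / "Plain Text" is treated as an
# IMAGE2TEXT model name and routed through the new VisionParser.
parser_config = {"layout_recognize": "gpt-4-vision-preview"}  # placeholder model name

chunks = naive.chunk(
    "sample.pdf",
    tenant_id="<tenant-id>",  # placeholder
    lang="English",
    parser_config=parser_config,
    callback=lambda prog=None, msg="": None,
)
```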

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Author: Yongteng Lei
Committed: 2025-03-18 14:52:20 +08:00 (committed by GitHub)
Parent: 897fe85b5c
Commit: 5cf610af40
7 changed files with 413 additions and 102 deletions

View File

@@ -15,13 +15,12 @@
#
import logging
-from api.db.services.user_service import TenantService
-from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
from api import settings
from api.db import LLMType
-from api.db.db_models import DB
-from api.db.db_models import LLMFactories, LLM, TenantLLM
+from api.db.db_models import DB, LLM, LLMFactories, TenantLLM
from api.db.services.common_service import CommonService
+from api.db.services.user_service import TenantService
+from rag.llm import ChatModel, CvModel, EmbeddingModel, RerankModel, Seq2txtModel, TTSModel


class LLMFactoriesService(CommonService):
@@ -266,6 +265,14 @@ class LLMBundle:
                "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
        return txt

+    def describe_with_prompt(self, image, prompt):
+        txt, used_tokens = self.mdl.describe_with_prompt(image, prompt)
+        if not TenantLLMService.increase_usage(
+                self.tenant_id, self.llm_type, used_tokens):
+            logging.error(
+                "LLMBundle.describe can't update token usage for {}/IMAGE2TEXT used_tokens: {}".format(self.tenant_id, used_tokens))
+        return txt
+
    def transcription(self, audio):
        txt, used_tokens = self.mdl.transcription(audio)
        if not TenantLLMService.increase_usage(

View File

@@ -17,26 +17,27 @@
import logging
import os
import random
-from timeit import default_timer as timer
+import re
import sys
import threading
-import trio
-import xgboost as xgb
+from copy import deepcopy
from io import BytesIO
-import re
-import pdfplumber
-from PIL import Image
+from timeit import default_timer as timer

import numpy as np
+import pdfplumber
+import trio
+import xgboost as xgb
+from huggingface_hub import snapshot_download
+from PIL import Image
from pypdf import PdfReader as pdf2_read

from api import settings
from api.utils.file_utils import get_project_base_directory
-from deepdoc.vision import OCR, Recognizer, LayoutRecognizer, TableStructureRecognizer
+from deepdoc.vision import OCR, LayoutRecognizer, Recognizer, TableStructureRecognizer
+from rag.app.picture import vision_llm_chunk as picture_vision_llm_chunk
from rag.nlp import rag_tokenizer
-from copy import deepcopy
-from huggingface_hub import snapshot_download
+from rag.prompts import vision_llm_describe_prompt
from rag.settings import PARALLEL_DEVICES

LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
@@ -45,7 +46,7 @@ if LOCK_KEY_pdfplumber not in sys.modules:

class RAGFlowPdfParser:
-    def __init__(self):
+    def __init__(self, **kwargs):
        """
        If you have trouble downloading HuggingFace models, -_^ this might help!!
@@ -57,12 +58,12 @@ class RAGFlowPdfParser:
        ^_-
        """
        self.ocr = OCR()
        self.parallel_limiter = None
        if PARALLEL_DEVICES is not None and PARALLEL_DEVICES > 1:
            self.parallel_limiter = [trio.CapacityLimiter(1) for _ in range(PARALLEL_DEVICES)]

        if hasattr(self, "model_speciess"):
            self.layouter = LayoutRecognizer("layout." + self.model_speciess)
        else:
@@ -106,7 +107,7 @@
    def _y_dis(
            self, a, b):
        return (
            b["top"] + b["bottom"] - a["top"] - a["bottom"]) / 2

    def _match_proj(self, b):
        proj_patt = [
@@ -129,9 +130,9 @@
        tks_down = rag_tokenizer.tokenize(down["text"][:LEN]).split()
        tks_up = rag_tokenizer.tokenize(up["text"][-LEN:]).split()
        tks_all = up["text"][-LEN:].strip() \
            + (" " if re.match(r"[a-zA-Z0-9]+",
                               up["text"][-1] + down["text"][0]) else "") \
            + down["text"][:LEN].strip()
        tks_all = rag_tokenizer.tokenize(tks_all).split()
        fea = [
            up.get("R", -1) == down.get("R", -1),
@@ -153,7 +154,7 @@
            True if re.search(r"[,][^。.]+$", up["text"]) else False,
            True if re.search(r"[,][^。.]+$", up["text"]) else False,
            True if re.search(r"[\(][^\)]+$", up["text"])
            and re.search(r"[\)]", down["text"]) else False,
            self._match_proj(down),
            True if re.match(r"[A-Z]", down["text"]) else False,
            True if re.match(r"[A-Z]", up["text"][-1]) else False,
@@ -215,7 +216,7 @@
                continue
            for tb in tbls:  # for table
                left, top, right, bott = tb["x0"] - MARGIN, tb["top"] - MARGIN, \
                    tb["x1"] + MARGIN, tb["bottom"] + MARGIN
                left *= ZM
                top *= ZM
                right *= ZM
@@ -309,7 +310,7 @@
             "page_number": pagenum} for b, t in bxs if b[0][0] <= b[1][0] and b[0][1] <= b[-1][1]],
            self.mean_height[-1] / 3
        )
        # merge chars in the same rect
        for c in Recognizer.sort_Y_firstly(
                chars, self.mean_height[pagenum - 1] // 4):
@@ -457,7 +458,7 @@
                    b_["text"],
                    any(feats),
                    any(concatting_feats),
                ))
                i += 1
                continue
            # merge up and down
@@ -665,7 +666,7 @@
                i += 1
                continue
            lout_no = str(self.boxes[i]["page_number"]) + \
                "-" + str(self.boxes[i]["layoutno"])
            if TableStructureRecognizer.is_caption(self.boxes[i]) or self.boxes[i]["layout_type"] in ["table caption",
                                                                                                      "title",
                                                                                                      "figure caption",
@@ -968,7 +969,7 @@
                fnm) if not binary else pdfplumber.open(BytesIO(binary))
            total_page = len(pdf.pages)
            pdf.close()
            return total_page
        except Exception:
            logging.exception("total_page_number")
@@ -994,7 +995,7 @@
            except Exception as e:
                logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
                self.page_chars = [[] for _ in range(page_to - page_from)]  # If failed to extract, using empty list instead.

            self.total_page = len(self.pdf.pages)
        except Exception:
            logging.exception("RAGFlowPdfParser __images__")
@@ -1023,7 +1024,7 @@
        logging.debug("Images converted.")
        self.is_english = [re.search(r"[a-zA-Z0-9,/¸;:'\[\]\(\)!@#$%^&*\"?<>._-]{30,}", "".join(
            random.choices([c["text"] for c in self.page_chars[i]], k=min(100, len(self.page_chars[i]))))) for i in
            range(len(self.page_chars))]
        if sum([1 if e else 0 for e in self.is_english]) > len(
                self.page_images) / 2:
            self.is_english = True
@@ -1036,7 +1037,7 @@
                if chars[j]["text"] and chars[j + 1]["text"] \
                        and re.match(r"[0-9a-zA-Z,.:;!%]+", chars[j]["text"] + chars[j + 1]["text"]) \
                        and chars[j + 1]["x0"] - chars[j]["x1"] >= min(chars[j + 1]["width"],
                                                                       chars[j]["width"]) / 2:
                    chars[j]["text"] += " "
                j += 1
@@ -1045,7 +1046,7 @@
                    await trio.to_thread.run_sync(lambda: self.__ocr(i + 1, img, chars, zoomin, id))
                else:
                    self.__ocr(i + 1, img, chars, zoomin, id)

                if callback and i % 6 == 5:
                    callback(prog=(i + 1) * 0.6 / len(self.page_images), msg="")
@@ -1060,14 +1061,14 @@
            )
            self.page_cum_height.append(img.size[1] / zoomin)
            return chars

        if self.parallel_limiter:
            async with trio.open_nursery() as nursery:
                for i, img in enumerate(self.page_images):
                    chars = __ocr_preprocess()
                    nursery.start_soon(__img_ocr, i, i % PARALLEL_DEVICES, img, chars,
                                       self.parallel_limiter[i % PARALLEL_DEVICES])
                    await trio.sleep(0.1)
        else:
            for i, img in enumerate(self.page_images):
@@ -1075,9 +1076,9 @@
                await __img_ocr(i, 0, img, chars, None)

        start = timer()
        trio.run(__img_ocr_launcher)
        logging.info(f"__images__ {len(self.page_images)} pages cost {timer() - start}s")

        if not self.is_english and not any(
@@ -1142,7 +1143,7 @@
                self.page_images[pns[0]].crop((left * ZM, top * ZM,
                                               right *
                                               ZM, min(
                                                   bottom, self.page_images[pns[0]].size[1])
                                               ))
            )
            if 0 < ii < len(poss) - 1:
@@ -1240,5 +1241,52 @@ class PlainParser:
        raise NotImplementedError


+class VisionParser(RAGFlowPdfParser):
+    def __init__(self, vision_model, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.vision_model = vision_model
+
+    def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
+        try:
+            with sys.modules[LOCK_KEY_pdfplumber]:
+                self.pdf = pdfplumber.open(fnm) if isinstance(
+                    fnm, str) else pdfplumber.open(BytesIO(fnm))
+                self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
+                                    enumerate(self.pdf.pages[page_from:page_to])]
+                self.total_page = len(self.pdf.pages)
+        except Exception:
+            self.page_images = None
+            self.total_page = 0
+            logging.exception("VisionParser __images__")
+
+    def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
+        callback = kwargs.get("callback", lambda prog, msg: None)
+        self.__images__(fnm=filename, zoomin=3, page_from=from_page, page_to=to_page, **kwargs)
+
+        total_pdf_pages = self.total_page
+
+        start_page = max(0, from_page)
+        end_page = min(to_page, total_pdf_pages)
+
+        all_docs = []
+
+        for idx, img_binary in enumerate(self.page_images or []):
+            pdf_page_num = idx  # 0-based
+            if pdf_page_num < start_page or pdf_page_num >= end_page:
+                continue
+
+            docs = picture_vision_llm_chunk(
+                binary=img_binary,
+                vision_model=self.vision_model,
+                prompt=vision_llm_describe_prompt(page=pdf_page_num+1),
+                callback=callback,
+            )
+
+            if docs:
+                all_docs.append(docs)
+        return [(doc, "") for doc in all_docs], []
+
+
if __name__ == "__main__":
    pass
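
A rough sketch of driving `VisionParser` directly, mirroring how the naive chunker wires it up in the next file; the tenant id and model name are placeholders, not values from this PR:

```python
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser.pdf_parser import VisionParser

# Placeholder identifiers -- substitute a real tenant id and IMAGE2TEXT model name.
vision_model = LLMBundle("<tenant-id>", LLMType.IMAGE2TEXT,
                         llm_name="gpt-4-vision-preview", lang="English")
parser = VisionParser(vision_model=vision_model)

# Each returned section is (markdown_text, ""), and the table list is always empty.
sections, tables = parser("manual.pdf", from_page=0, to_page=8,
                          callback=lambda prog=None, msg="": None)
```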

View File

@@ -26,8 +26,10 @@ from markdown import markdown
from PIL import Image
from tika import parser

+from api.db import LLMType
+from api.db.services.llm_service import LLMBundle
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownParser, PdfParser, TxtParser
-from deepdoc.parser.pdf_parser import PlainParser
+from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_docx, tokenize_table
from rag.utils import num_tokens_from_string
@@ -237,9 +239,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
        return res

    elif re.search(r"\.pdf$", filename, re.IGNORECASE):
-        pdf_parser = Pdf()
-        if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
+        layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
+
+        if layout_recognizer == "DeepDOC":
+            pdf_parser = Pdf()
+        elif layout_recognizer == "Plain Text":
            pdf_parser = PlainParser()
+        else:
+            vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
+            pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
+
        sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
                                      callback=callback)
        res = tokenize_table(tables, doc, is_english)

View File

@@ -21,8 +21,9 @@ from PIL import Image
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
-from rag.nlp import tokenize
from deepdoc.vision import OCR
+from rag.nlp import tokenize
+from rag.utils import clean_markdown_block

ocr = OCR()
@@ -57,3 +58,32 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
        callback(prog=-1, msg=str(e))

    return []
+
+
+def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
+    """
+    A simple wrapper to process image to markdown texts via VLM.
+
+    Returns:
+        Simple markdown texts generated by VLM.
+    """
+
+    callback = callback or (lambda prog, msg: None)
+
+    img = binary
+    txt = ""
+
+    try:
+        img_binary = io.BytesIO()
+        img.save(img_binary, format='JPEG')
+        img_binary.seek(0)
+
+        ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt))
+
+        txt += "\n" + ans
+
+        return txt
+
+    except Exception as e:
+        callback(-1, str(e))
+
+    return []
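
Worth noting: despite the parameter name `binary`, `vision_llm_chunk` receives a PIL image (the page rendered by `VisionParser.__images__`) and re-encodes it as JPEG before calling the model. A minimal sketch with placeholder names:

```python
from PIL import Image

from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from rag.app.picture import vision_llm_chunk
from rag.prompts import vision_llm_describe_prompt

# Placeholder tenant id and model name.
vision_model = LLMBundle("<tenant-id>", LLMType.IMAGE2TEXT, llm_name="qwen-vl-max")

page_image = Image.open("page_1.png")  # any PIL image of a PDF page
markdown = vision_llm_chunk(
    binary=page_image,
    vision_model=vision_model,
    prompt=vision_llm_describe_prompt(page=1),  # adds a "--- Page 1 ---" divider instruction
)
print(markdown)
```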

View File

@@ -13,31 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from openai.lib.azure import AzureOpenAI
-from zhipuai import ZhipuAI
-import io
-from abc import ABC
-from ollama import Client
-from PIL import Image
-from openai import OpenAI
-import os
import base64
-from io import BytesIO
+import io
import json
-import requests
+import os
+from abc import ABC
+from io import BytesIO

+import requests
+from ollama import Client
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+from PIL import Image
+from zhipuai import ZhipuAI

-from rag.nlp import is_english
from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory
+from rag.nlp import is_english
+from rag.prompts import vision_llm_describe_prompt


class Base(ABC):
    def __init__(self, key, model_name):
        pass

-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
        raise NotImplementedError("Please implement encode method!")

+    def describe_with_prompt(self, image, prompt=None):
+        raise NotImplementedError("Please implement encode method!")
+
    def chat(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -90,7 +95,7 @@ class Base(ABC):
            yield ans + "\n**ERROR**: " + str(e)
        yield tk_count

    def image2base64(self, image):
        if isinstance(image, bytes):
            return base64.b64encode(image).decode("utf-8")
@@ -122,6 +127,25 @@
            }
        ]

+    def vision_llm_prompt(self, b64, prompt=None):
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image_url",
+                        "image_url": {
+                            "url": f"data:image/jpeg;base64,{b64}"
+                        },
+                    },
+                    {
+                        "type": "text",
+                        "text": prompt if prompt else vision_llm_describe_prompt(),
+                    },
+                ],
+            }
+        ]
+
    def chat_prompt(self, text, b64):
        return [
            {
@@ -140,12 +164,12 @@

class GptV4(Base):
    def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
        if not base_url:
-            base_url="https://api.openai.com/v1"
+            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang

-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64)
        for i in range(len(prompt)):
@@ -159,6 +183,16 @@ class GptV4(Base):
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=vision_prompt,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+

class AzureGptV4(Base):
    def __init__(self, key, model_name, lang="Chinese", **kwargs):
@@ -168,7 +202,7 @@ class AzureGptV4(Base):
        self.model_name = model_name
        self.lang = lang

-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64)
        for i in range(len(prompt)):
@@ -182,6 +216,16 @@ class AzureGptV4(Base):
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=vision_prompt,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+

class QWenCV(Base):
    def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
@@ -212,23 +256,57 @@ class QWenCV(Base):
            }
        ]

+    def vision_llm_prompt(self, binary, prompt=None):
+        # stupid as hell
+        tmp_dir = get_project_base_directory("tmp")
+        if not os.path.exists(tmp_dir):
+            os.mkdir(tmp_dir)
+        path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
+        Image.open(io.BytesIO(binary)).save(path)
+
+        return [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "image": f"file://{path}"
+                    },
+                    {
+                        "text": prompt if prompt else vision_llm_describe_prompt(),
+                    },
+                ],
+            }
+        ]
+
    def chat_prompt(self, text, b64):
        return [
            {"image": f"{b64}"},
            {"text": text},
        ]

-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
-        response = MultiModalConversation.call(model=self.model_name,
-                                               messages=self.prompt(image))
+        response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
+        if response.status_code == HTTPStatus.OK:
+            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+        return response.message, 0
+
+    def describe_with_prompt(self, image, prompt=None):
+        from http import HTTPStatus
+        from dashscope import MultiModalConversation
+        vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
+        response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
        if response.status_code == HTTPStatus.OK:
            return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
        return response.message, 0

    def chat(self, system, history, gen_conf, image=""):
        from http import HTTPStatus
        from dashscope import MultiModalConversation
        if system:
            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -254,6 +332,7 @@ class QWenCV(Base):
    def chat_streamly(self, system, history, gen_conf, image=""):
        from http import HTTPStatus
        from dashscope import MultiModalConversation

        if system:
            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -292,15 +371,25 @@ class Zhipu4V(Base):
        self.model_name = model_name
        self.lang = lang

-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
        b64 = self.image2base64(image)
        prompt = self.prompt(b64)
        prompt[0]["content"][1]["type"] = "text"
        res = self.client.chat.completions.create(
            model=self.model_name,
-            messages=prompt
+            messages=prompt,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+
+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=vision_prompt
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -334,7 +423,7 @@ class Zhipu4V(Base):
                his["content"] = self.chat_prompt(his["content"], image)

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=history,
            temperature=gen_conf.get("temperature", 0.3),
            top_p=gen_conf.get("top_p", 0.7),
@@ -364,7 +453,7 @@ class OllamaCV(Base):
        self.model_name = model_name
        self.lang = lang

-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
        prompt = self.prompt("")
        try:
            response = self.client.generate(
@@ -377,6 +466,19 @@ class OllamaCV(Base):
        except Exception as e:
            return "**ERROR**: " + str(e), 0

+    def describe_with_prompt(self, image, prompt=None):
+        vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
+        try:
+            response = self.client.generate(
+                model=self.model_name,
+                prompt=vision_prompt[0]["content"][1]["text"],
+                images=[image],
+            )
+            ans = response["response"].strip()
+            return ans, 128
+        except Exception as e:
+            return "**ERROR**: " + str(e), 0
+
    def chat(self, system, history, gen_conf, image=""):
        if system:
            history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -460,7 +562,7 @@ class XinferenceCV(Base):
        self.model_name = model_name
        self.lang = lang

-    def describe(self, image, max_tokens=300):
+    def describe(self, image):
        b64 = self.image2base64(image)

        res = self.client.chat.completions.create(
@@ -469,27 +571,49 @@
        )
        return res.choices[0].message.content.strip(), res.usage.total_tokens

+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        res = self.client.chat.completions.create(
+            model=self.model_name,
+            messages=vision_prompt,
+        )
+        return res.choices[0].message.content.strip(), res.usage.total_tokens
+

class GeminiCV(Base):
    def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
-        from google.generativeai import client, GenerativeModel
+        from google.generativeai import GenerativeModel, client
        client.configure(api_key=key)
        _client = client.get_default_generative_client()
        self.model_name = model_name
        self.model = GenerativeModel(model_name=self.model_name)
        self.model._client = _client
        self.lang = lang

-    def describe(self, image, max_tokens=2048):
+    def describe(self, image):
        from PIL.Image import open
        prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
            "Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
        b64 = self.image2base64(image)
        img = open(BytesIO(base64.b64decode(b64)))
-        input = [prompt,img]
+        input = [prompt, img]
        res = self.model.generate_content(
            input
        )
-        return res.text,res.usage_metadata.total_token_count
+        return res.text, res.usage_metadata.total_token_count
+
+    def describe_with_prompt(self, image, prompt=None):
+        from PIL.Image import open
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        img = open(BytesIO(base64.b64decode(b64)))
+        input = [vision_prompt, img]
+        res = self.model.generate_content(
+            input,
+        )
+        return res.text, res.usage_metadata.total_token_count

    def chat(self, system, history, gen_conf, image=""):
        from transformers import GenerationConfig
@@ -566,7 +690,7 @@ class LocalCV(Base):
    def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
        pass

-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
        return "", 0
@@ -590,7 +714,7 @@ class NvidiaCV(Base):
        )
        self.key = key

-    def describe(self, image, max_tokens=1024):
+    def describe(self, image):
        b64 = self.image2base64(image)
        response = requests.post(
            url=self.base_url,
@@ -609,6 +733,27 @@ class NvidiaCV(Base):
            response["usage"]["total_tokens"],
        )

+    def describe_with_prompt(self, image, prompt=None):
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        response = requests.post(
+            url=self.base_url,
+            headers={
+                "accept": "application/json",
+                "content-type": "application/json",
+                "Authorization": f"Bearer {self.key}",
+            },
+            json={
+                "messages": vision_prompt,
+            },
+        )
+        response = response.json()
+
+        return (
+            response["choices"][0]["message"]["content"].strip(),
+            response["usage"]["total_tokens"],
+        )
+
    def prompt(self, b64):
        return [
            {
@@ -622,6 +767,17 @@ class NvidiaCV(Base):
            }
        ]

+    def vision_llm_prompt(self, b64, prompt=None):
+        return [
+            {
+                "role": "user",
+                "content": (
+                    prompt if prompt else vision_llm_describe_prompt()
+                )
+                + f' <img src="data:image/jpeg;base64,{b64}"/>',
+            }
+        ]
+
    def chat_prompt(self, text, b64):
        return [
            {
@@ -634,7 +790,7 @@

class StepFunCV(GptV4):
    def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
        if not base_url:
-            base_url="https://api.stepfun.com/v1"
+            base_url = "https://api.stepfun.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name
        self.lang = lang
@@ -666,18 +822,18 @@ class TogetherAICV(GptV4):
    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
        if not base_url:
            base_url = "https://api.together.xyz/v1"
-        super().__init__(key, model_name,lang,base_url)
+        super().__init__(key, model_name, lang, base_url)


class YiCV(GptV4):
-    def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
+    def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
        if not base_url:
            base_url = "https://api.lingyiwanwu.com/v1"
-        super().__init__(key, model_name,lang,base_url)
+        super().__init__(key, model_name, lang, base_url)


class HunyuanCV(Base):
-    def __init__(self, key, model_name, lang="Chinese",base_url=None):
+    def __init__(self, key, model_name, lang="Chinese", base_url=None):
        from tencentcloud.common import credential
        from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -689,12 +845,12 @@ class HunyuanCV(Base):
        self.client = hunyuan_client.HunyuanClient(cred, "")
        self.lang = lang

-    def describe(self, image, max_tokens=4096):
-        from tencentcloud.hunyuan.v20230901 import models
+    def describe(self, image):
        from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
            TencentCloudSDKException,
        )
+        from tencentcloud.hunyuan.v20230901 import models
        b64 = self.image2base64(image)
        req = models.ChatCompletionsRequest()
        params = {"Model": self.model_name, "Messages": self.prompt(b64)}
@@ -706,7 +862,24 @@
            return ans, response.Usage.TotalTokens
        except TencentCloudSDKException as e:
            return ans + "\n**ERROR**: " + str(e), 0

+    def describe_with_prompt(self, image, prompt=None):
+        from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
+        from tencentcloud.hunyuan.v20230901 import models
+
+        b64 = self.image2base64(image)
+        vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+        req = models.ChatCompletionsRequest()
+        params = {"Model": self.model_name, "Messages": vision_prompt}
+        req.from_json_string(json.dumps(params))
+        ans = ""
+        try:
+            response = self.client.ChatCompletions(req)
+            ans = response.Choices[0].Message.Content
+            return ans, response.Usage.TotalTokens
+        except TencentCloudSDKException as e:
+            return ans + "\n**ERROR**: " + str(e), 0
+
    def prompt(self, b64):
        return [
            {
@@ -725,4 +898,4 @@
                },
            ],
        }
    ]
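
A rough sketch of exercising the new `describe_with_prompt` path on one of these CV model classes directly; the API key, model name, and file are placeholders, and the module path is assumed to be `rag.llm.cv_model`:

```python
from rag.llm.cv_model import GptV4
from rag.prompts import vision_llm_describe_prompt

# Placeholder credentials and model name.
model = GptV4(key="<api-key>", model_name="gpt-4o", lang="English")

with open("page_1.jpg", "rb") as f:
    image_bytes = f.read()

# Returns (markdown_text, total_tokens), mirroring describe().
markdown, used_tokens = model.describe_with_prompt(
    image_bytes,
    vision_llm_describe_prompt(page=1),
)
print(used_tokens, markdown[:200])
```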

View File

@@ -18,13 +18,13 @@ import json
import logging
import re
from collections import defaultdict

import json_repair

from api import settings
from api.db import LLMType
-from api.db.services.document_service import DocumentService
-from api.db.services.llm_service import TenantLLMService, LLMBundle
from rag.settings import TAG_FLD
-from rag.utils import num_tokens_from_string, encoder
+from rag.utils import encoder, num_tokens_from_string


def chunks_format(reference):
@@ -44,9 +44,11 @@ def chunks_format(reference):

def llm_id2llm_type(llm_id):
+    from api.db.services.llm_service import TenantLLMService
+
    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
    llm_factories = settings.FACTORY_LLM_INFOS
    for llm_factory in llm_factories:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
@@ -92,6 +94,8 @@ def message_fit_in(msg, max_length=4000):

def kb_prompt(kbinfos, max_tokens):
+    from api.db.services.document_service import DocumentService
+
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    used_token_count = 0
    chunks_num = 0
@@ -166,15 +170,15 @@ Overall, while Musk enjoys Dogecoin and often promotes it, he also warns against

def keyword_extraction(chat_mdl, content, topn=3):
    prompt = f"""
Role: You're a text analyzer.
Task: extract the most important keywords/phrases of a given piece of text content.
Requirements:
- Summarize the text content, and give top {topn} important keywords/phrases.
- The keywords MUST be in language of the given piece of text content.
- The keywords are delimited by ENGLISH COMMA.
- Keywords ONLY in output.

### Text Content
{content}
"""
@@ -194,9 +198,9 @@ Requirements:

def question_proposal(chat_mdl, content, topn=3):
    prompt = f"""
Role: You're a text analyzer.
Task: propose {topn} questions about a given piece of text content.
Requirements:
- Understand and summarize the text content, and propose top {topn} important questions.
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
@@ -204,7 +208,7 @@ Requirements:
- One question per line.
- Question ONLY in output.

### Text Content
{content}
"""
@@ -223,6 +227,8 @@ Requirements:

def full_question(tenant_id, llm_id, messages, language=None):
+    from api.db.services.llm_service import LLMBundle
+
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
@@ -239,7 +245,7 @@ def full_question(tenant_id, llm_id, messages, language=None):
    prompt = f"""
Role: A helpful assistant

Task and steps:
1. Generate a full user question that would follow the conversation.
2. If the user's question involves relative date, you need to convert it into absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.
@@ -300,11 +306,11 @@ Output: What's the weather in Rochester on {tomorrow}?

def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    prompt = f"""
Role: You're a text analyzer.

Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.

Steps::
- Comprehend the tag/label set.
- Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
- Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.
@@ -358,3 +364,32 @@ Output:
    except Exception as e:
        logging.exception(f"JSON parsing error: {result} -> {e}")
        raise e
+
+
+def vision_llm_describe_prompt(page=None) -> str:
+    prompt_en = """
+INSTRUCTION:
+Transcribe the content from the provided PDF page image into clean Markdown format.
+- Only output the content transcribed from the image.
+- Do NOT output this instruction or any other explanation.
+- If the content is missing or you do not understand the input, return an empty string.
+
+RULES:
+1. Do NOT generate examples, demonstrations, or templates.
+2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
+3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
+4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
+5. Do NOT explain Markdown or mention that you are using Markdown.
+6. Do NOT wrap the output in ```markdown or ``` blocks.
+7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
+8. Preserve the original language, information, and order exactly as shown in the image.
+"""
+
+    if page is not None:
+        prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`."
+
+    prompt_en += """
+FAILURE HANDLING:
+- If you do not detect valid content in the image, return an empty string.
+"""
+
+    return prompt_en
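
As `VisionParser` does above, the prompt is rebuilt per page so each transcription ends with a page divider; a quick illustration:

```python
from rag.prompts import vision_llm_describe_prompt

p = vision_llm_describe_prompt(page=3)
assert "--- Page 3 ---" in p  # divider instruction is appended when a page number is given
assert "--- Page" not in vision_llm_describe_prompt()  # and omitted otherwise
```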

View File

@@ -16,7 +16,9 @@
import os
import re

import tiktoken

from api.utils.file_utils import get_project_base_directory
@@ -54,7 +56,7 @@ def findMaxDt(fnm):
        pass
    return m


def findMaxTm(fnm):
    m = 0
    try:
@@ -91,11 +93,18 @@ def truncate(string: str, max_len: int) -> str:
    """Returns truncated text if the length of text exceed max_len."""
    return encoder.decode(encoder.encode(string)[:max_len])


+def clean_markdown_block(text):
+    text = re.sub(r'^\s*```markdown\s*\n?', '', text)
+    text = re.sub(r'\n?\s*```\s*$', '', text)
+    return text.strip()
+
+
def get_float(v: str | None):
    if v is None:
        return float('-inf')
    try:
        return float(v)
    except Exception:
        return float('-inf')
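
A quick illustration of the new `clean_markdown_block` helper on the fenced output some vision models return despite the prompt's instructions:

```python
from rag.utils import clean_markdown_block

raw = "```markdown\n# Page 1\nSome transcribed text.\n```"
print(clean_markdown_block(raw))               # -> "# Page 1\nSome transcribed text."
print(clean_markdown_block(" plain text \n"))  # unfenced text is only stripped -> "plain text"
```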