Feat: add vision LLM PDF parser (#6173)

### What problem does this PR solve?

Add vision LLM PDF parser

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
Yongteng Lei
2025-03-18 14:52:20 +08:00
committed by GitHub
parent 897fe85b5c
commit 5cf610af40
7 changed files with 413 additions and 102 deletions

View File

@ -26,8 +26,10 @@ from markdown import markdown
from PIL import Image
from tika import parser
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from deepdoc.parser import DocxParser, ExcelParser, HtmlParser, JsonParser, MarkdownParser, PdfParser, TxtParser
from deepdoc.parser.pdf_parser import PlainParser
from deepdoc.parser.pdf_parser import PlainParser, VisionParser
from rag.nlp import concat_img, find_codec, naive_merge, naive_merge_docx, rag_tokenizer, tokenize_chunks, tokenize_chunks_docx, tokenize_table
from rag.utils import num_tokens_from_string
@ -237,9 +239,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
return res
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
layout_recognizer = parser_config.get("layout_recognize", "DeepDOC")
if layout_recognizer == "DeepDOC":
pdf_parser = Pdf()
elif layout_recognizer == "Plain Text":
pdf_parser = PlainParser()
else:
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT, llm_name=layout_recognizer, lang=lang)
pdf_parser = VisionParser(vision_model=vision_model, **kwargs)
sections, tables = pdf_parser(filename if not binary else binary, from_page=from_page, to_page=to_page,
callback=callback)
res = tokenize_table(tables, doc, is_english)

View File

@ -21,8 +21,9 @@ from PIL import Image
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from rag.nlp import tokenize
from deepdoc.vision import OCR
from rag.nlp import tokenize
from rag.utils import clean_markdown_block
ocr = OCR()
@ -57,3 +58,32 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
callback(prog=-1, msg=str(e))
return []
def vision_llm_chunk(binary, vision_model, prompt=None, callback=None):
"""
A simple wrapper to process image to markdown texts via VLM.
Returns:
Simple markdown texts generated by VLM.
"""
callback = callback or (lambda prog, msg: None)
img = binary
txt = ""
try:
img_binary = io.BytesIO()
img.save(img_binary, format='JPEG')
img_binary.seek(0)
ans = clean_markdown_block(vision_model.describe_with_prompt(img_binary.read(), prompt))
txt += "\n" + ans
return txt
except Exception as e:
callback(-1, str(e))
return []

View File

@ -13,31 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from openai.lib.azure import AzureOpenAI
from zhipuai import ZhipuAI
import io
from abc import ABC
from ollama import Client
from PIL import Image
from openai import OpenAI
import os
import base64
from io import BytesIO
import io
import json
import requests
import os
from abc import ABC
from io import BytesIO
import requests
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from PIL import Image
from zhipuai import ZhipuAI
from rag.nlp import is_english
from api.utils import get_uuid
from api.utils.file_utils import get_project_base_directory
from rag.nlp import is_english
from rag.prompts import vision_llm_describe_prompt
class Base(ABC):
def __init__(self, key, model_name):
pass
def describe(self, image, max_tokens=300):
def describe(self, image):
raise NotImplementedError("Please implement encode method!")
def describe_with_prompt(self, image, prompt=None):
raise NotImplementedError("Please implement encode method!")
def chat(self, system, history, gen_conf, image=""):
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@ -90,7 +95,7 @@ class Base(ABC):
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
def image2base64(self, image):
if isinstance(image, bytes):
return base64.b64encode(image).decode("utf-8")
@ -122,6 +127,25 @@ class Base(ABC):
}
]
def vision_llm_prompt(self, b64, prompt=None):
return [
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{b64}"
},
},
{
"type": "text",
"text": prompt if prompt else vision_llm_describe_prompt(),
},
],
}
]
def chat_prompt(self, text, b64):
return [
{
@ -140,12 +164,12 @@ class Base(ABC):
class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url:
base_url="https://api.openai.com/v1"
base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
def describe(self, image, max_tokens=300):
def describe(self, image):
b64 = self.image2base64(image)
prompt = self.prompt(b64)
for i in range(len(prompt)):
@ -159,6 +183,16 @@ class GptV4(Base):
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
res = self.client.chat.completions.create(
model=self.model_name,
messages=vision_prompt,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
class AzureGptV4(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs):
@ -168,7 +202,7 @@ class AzureGptV4(Base):
self.model_name = model_name
self.lang = lang
def describe(self, image, max_tokens=300):
def describe(self, image):
b64 = self.image2base64(image)
prompt = self.prompt(b64)
for i in range(len(prompt)):
@ -182,6 +216,16 @@ class AzureGptV4(Base):
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
res = self.client.chat.completions.create(
model=self.model_name,
messages=vision_prompt,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
class QWenCV(Base):
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
@ -212,23 +256,57 @@ class QWenCV(Base):
}
]
def vision_llm_prompt(self, binary, prompt=None):
# stupid as hell
tmp_dir = get_project_base_directory("tmp")
if not os.path.exists(tmp_dir):
os.mkdir(tmp_dir)
path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
Image.open(io.BytesIO(binary)).save(path)
return [
{
"role": "user",
"content": [
{
"image": f"file://{path}"
},
{
"text": prompt if prompt else vision_llm_describe_prompt(),
},
],
}
]
def chat_prompt(self, text, b64):
return [
{"image": f"{b64}"},
{"text": text},
]
def describe(self, image, max_tokens=300):
def describe(self, image):
from http import HTTPStatus
from dashscope import MultiModalConversation
response = MultiModalConversation.call(model=self.model_name,
messages=self.prompt(image))
response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
return response.message, 0
def describe_with_prompt(self, image, prompt=None):
from http import HTTPStatus
from dashscope import MultiModalConversation
vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
return response.message, 0
def chat(self, system, history, gen_conf, image=""):
from http import HTTPStatus
from dashscope import MultiModalConversation
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@ -254,6 +332,7 @@ class QWenCV(Base):
def chat_streamly(self, system, history, gen_conf, image=""):
from http import HTTPStatus
from dashscope import MultiModalConversation
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@ -292,15 +371,25 @@ class Zhipu4V(Base):
self.model_name = model_name
self.lang = lang
def describe(self, image, max_tokens=1024):
def describe(self, image):
b64 = self.image2base64(image)
prompt = self.prompt(b64)
prompt[0]["content"][1]["type"] = "text"
res = self.client.chat.completions.create(
model=self.model_name,
messages=prompt
messages=prompt,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
res = self.client.chat.completions.create(
model=self.model_name,
messages=vision_prompt
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
@ -334,7 +423,7 @@ class Zhipu4V(Base):
his["content"] = self.chat_prompt(his["content"], image)
response = self.client.chat.completions.create(
model=self.model_name,
model=self.model_name,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
@ -364,7 +453,7 @@ class OllamaCV(Base):
self.model_name = model_name
self.lang = lang
def describe(self, image, max_tokens=1024):
def describe(self, image):
prompt = self.prompt("")
try:
response = self.client.generate(
@ -377,6 +466,19 @@ class OllamaCV(Base):
except Exception as e:
return "**ERROR**: " + str(e), 0
def describe_with_prompt(self, image, prompt=None):
vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
try:
response = self.client.generate(
model=self.model_name,
prompt=vision_prompt[0]["content"][1]["text"],
images=[image],
)
ans = response["response"].strip()
return ans, 128
except Exception as e:
return "**ERROR**: " + str(e), 0
def chat(self, system, history, gen_conf, image=""):
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@ -460,7 +562,7 @@ class XinferenceCV(Base):
self.model_name = model_name
self.lang = lang
def describe(self, image, max_tokens=300):
def describe(self, image):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
@ -469,27 +571,49 @@ class XinferenceCV(Base):
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
res = self.client.chat.completions.create(
model=self.model_name,
messages=vision_prompt,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
class GeminiCV(Base):
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
from google.generativeai import client, GenerativeModel
from google.generativeai import GenerativeModel, client
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client
self.lang = lang
self.lang = lang
def describe(self, image, max_tokens=2048):
def describe(self, image):
from PIL.Image import open
prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
b64 = self.image2base64(image)
img = open(BytesIO(base64.b64decode(b64)))
input = [prompt,img]
b64 = self.image2base64(image)
img = open(BytesIO(base64.b64decode(b64)))
input = [prompt, img]
res = self.model.generate_content(
input
)
return res.text,res.usage_metadata.total_token_count
return res.text, res.usage_metadata.total_token_count
def describe_with_prompt(self, image, prompt=None):
from PIL.Image import open
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
img = open(BytesIO(base64.b64decode(b64)))
input = [vision_prompt, img]
res = self.model.generate_content(
input,
)
return res.text, res.usage_metadata.total_token_count
def chat(self, system, history, gen_conf, image=""):
from transformers import GenerationConfig
@ -566,7 +690,7 @@ class LocalCV(Base):
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
pass
def describe(self, image, max_tokens=1024):
def describe(self, image):
return "", 0
@ -590,7 +714,7 @@ class NvidiaCV(Base):
)
self.key = key
def describe(self, image, max_tokens=1024):
def describe(self, image):
b64 = self.image2base64(image)
response = requests.post(
url=self.base_url,
@ -609,6 +733,27 @@ class NvidiaCV(Base):
response["usage"]["total_tokens"],
)
def describe_with_prompt(self, image, prompt=None):
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
response = requests.post(
url=self.base_url,
headers={
"accept": "application/json",
"content-type": "application/json",
"Authorization": f"Bearer {self.key}",
},
json={
"messages": vision_prompt,
},
)
response = response.json()
return (
response["choices"][0]["message"]["content"].strip(),
response["usage"]["total_tokens"],
)
def prompt(self, b64):
return [
{
@ -622,6 +767,17 @@ class NvidiaCV(Base):
}
]
def vision_llm_prompt(self, b64, prompt=None):
return [
{
"role": "user",
"content": (
prompt if prompt else vision_llm_describe_prompt()
)
+ f' <img src="data:image/jpeg;base64,{b64}"/>',
}
]
def chat_prompt(self, text, b64):
return [
{
@ -634,7 +790,7 @@ class NvidiaCV(Base):
class StepFunCV(GptV4):
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url:
base_url="https://api.stepfun.com/v1"
base_url = "https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
@ -666,18 +822,18 @@ class TogetherAICV(GptV4):
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
super().__init__(key, model_name,lang,base_url)
super().__init__(key, model_name, lang, base_url)
class YiCV(GptV4):
def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
if not base_url:
base_url = "https://api.lingyiwanwu.com/v1"
super().__init__(key, model_name,lang,base_url)
super().__init__(key, model_name, lang, base_url)
class HunyuanCV(Base):
def __init__(self, key, model_name, lang="Chinese",base_url=None):
def __init__(self, key, model_name, lang="Chinese", base_url=None):
from tencentcloud.common import credential
from tencentcloud.hunyuan.v20230901 import hunyuan_client
@ -689,12 +845,12 @@ class HunyuanCV(Base):
self.client = hunyuan_client.HunyuanClient(cred, "")
self.lang = lang
def describe(self, image, max_tokens=4096):
from tencentcloud.hunyuan.v20230901 import models
def describe(self, image):
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
from tencentcloud.hunyuan.v20230901 import models
b64 = self.image2base64(image)
req = models.ChatCompletionsRequest()
params = {"Model": self.model_name, "Messages": self.prompt(b64)}
@ -706,7 +862,24 @@ class HunyuanCV(Base):
return ans, response.Usage.TotalTokens
except TencentCloudSDKException as e:
return ans + "\n**ERROR**: " + str(e), 0
def describe_with_prompt(self, image, prompt=None):
from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
from tencentcloud.hunyuan.v20230901 import models
b64 = self.image2base64(image)
vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
req = models.ChatCompletionsRequest()
params = {"Model": self.model_name, "Messages": vision_prompt}
req.from_json_string(json.dumps(params))
ans = ""
try:
response = self.client.ChatCompletions(req)
ans = response.Choices[0].Message.Content
return ans, response.Usage.TotalTokens
except TencentCloudSDKException as e:
return ans + "\n**ERROR**: " + str(e), 0
def prompt(self, b64):
return [
{
@ -725,4 +898,4 @@ class HunyuanCV(Base):
},
],
}
]
]

View File

@ -18,13 +18,13 @@ import json
import logging
import re
from collections import defaultdict
import json_repair
from api import settings
from api.db import LLMType
from api.db.services.document_service import DocumentService
from api.db.services.llm_service import TenantLLMService, LLMBundle
from rag.settings import TAG_FLD
from rag.utils import num_tokens_from_string, encoder
from rag.utils import encoder, num_tokens_from_string
def chunks_format(reference):
@ -44,9 +44,11 @@ def chunks_format(reference):
def llm_id2llm_type(llm_id):
from api.db.services.llm_service import TenantLLMService
llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
llm_factories = settings.FACTORY_LLM_INFOS
llm_factories = settings.FACTORY_LLM_INFOS
for llm_factory in llm_factories:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
@ -92,6 +94,8 @@ def message_fit_in(msg, max_length=4000):
def kb_prompt(kbinfos, max_tokens):
from api.db.services.document_service import DocumentService
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
used_token_count = 0
chunks_num = 0
@ -166,15 +170,15 @@ Overall, while Musk enjoys Dogecoin and often promotes it, he also warns against
def keyword_extraction(chat_mdl, content, topn=3):
prompt = f"""
Role: You're a text analyzer.
Role: You're a text analyzer.
Task: extract the most important keywords/phrases of a given piece of text content.
Requirements:
Requirements:
- Summarize the text content, and give top {topn} important keywords/phrases.
- The keywords MUST be in language of the given piece of text content.
- The keywords are delimited by ENGLISH COMMA.
- Keywords ONLY in output.
### Text Content
### Text Content
{content}
"""
@ -194,9 +198,9 @@ Requirements:
def question_proposal(chat_mdl, content, topn=3):
prompt = f"""
Role: You're a text analyzer.
Role: You're a text analyzer.
Task: propose {topn} questions about a given piece of text content.
Requirements:
Requirements:
- Understand and summarize the text content, and propose top {topn} important questions.
- The questions SHOULD NOT have overlapping meanings.
- The questions SHOULD cover the main content of the text as much as possible.
@ -204,7 +208,7 @@ Requirements:
- One question per line.
- Question ONLY in output.
### Text Content
### Text Content
{content}
"""
@ -223,6 +227,8 @@ Requirements:
def full_question(tenant_id, llm_id, messages, language=None):
from api.db.services.llm_service import LLMBundle
if llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
@ -239,7 +245,7 @@ def full_question(tenant_id, llm_id, messages, language=None):
prompt = f"""
Role: A helpful assistant
Task and steps:
Task and steps:
1. Generate a full user question that would follow the conversation.
2. If the user's question involves relative date, you need to convert it into absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.
@ -300,11 +306,11 @@ Output: What's the weather in Rochester on {tomorrow}?
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
prompt = f"""
Role: You're a text analyzer.
Role: You're a text analyzer.
Task: Tag (put on some labels) to a given piece of text content based on the examples and the entire tag set.
Steps::
Steps::
- Comprehend the tag/label set.
- Comprehend examples which all consist of both text content and assigned tags with relevance score in format of JSON.
- Summarize the text content, and tag it with top {topn} most relevant tags from the set of tag/label and the corresponding relevance score.
@ -358,3 +364,32 @@ Output:
except Exception as e:
logging.exception(f"JSON parsing error: {result} -> {e}")
raise e
def vision_llm_describe_prompt(page=None) -> str:
prompt_en = """
INSTRUCTION:
Transcribe the content from the provided PDF page image into clean Markdown format.
- Only output the content transcribed from the image.
- Do NOT output this instruction or any other explanation.
- If the content is missing or you do not understand the input, return an empty string.
RULES:
1. Do NOT generate examples, demonstrations, or templates.
2. Do NOT output any extra text such as 'Example', 'Example Output', or similar.
3. Do NOT generate any tables, headings, or content that is not explicitly present in the image.
4. Transcribe content word-for-word. Do NOT modify, translate, or omit any content.
5. Do NOT explain Markdown or mention that you are using Markdown.
6. Do NOT wrap the output in ```markdown or ``` blocks.
7. Only apply Markdown structure to headings, paragraphs, lists, and tables, strictly based on the layout of the image. Do NOT create tables unless an actual table exists in the image.
8. Preserve the original language, information, and order exactly as shown in the image.
"""
if page is not None:
prompt_en += f"\nAt the end of the transcription, add the page divider: `--- Page {page} ---`."
prompt_en += """
FAILURE HANDLING:
- If you do not detect valid content in the image, return an empty string.
"""
return prompt_en

View File

@ -16,7 +16,9 @@
import os
import re
import tiktoken
from api.utils.file_utils import get_project_base_directory
@ -54,7 +56,7 @@ def findMaxDt(fnm):
pass
return m
def findMaxTm(fnm):
m = 0
try:
@ -91,11 +93,18 @@ def truncate(string: str, max_len: int) -> str:
"""Returns truncated text if the length of text exceed max_len."""
return encoder.decode(encoder.encode(string)[:max_len])
def clean_markdown_block(text):
text = re.sub(r'^\s*```markdown\s*\n?', '', text)
text = re.sub(r'\n?\s*```\s*$', '', text)
return text.strip()
def get_float(v: str | None):
if v is None:
return float('-inf')
try:
return float(v)
except Exception:
return float('-inf')
return float('-inf')