Feat: add vision LLM PDF parser (#6173)

### What problem does this PR solve?

Add a vision-LLM-based PDF parser. In the model layer shown below, every CV wrapper gains a `describe_with_prompt(image, prompt=None)` method that sends the image together with a caller-supplied prompt and falls back to the shared `vision_llm_describe_prompt()` template when none is given; the unused `max_tokens` parameter is also dropped from `describe()`. A usage sketch follows.
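
A minimal usage sketch of the new method (the module path `rag.llm.cv_model`, the API key, and the image file are illustrative assumptions, not part of this diff; the default prompt comes from `rag.prompts.vision_llm_describe_prompt()`):

```python
# Hypothetical usage of the describe_with_prompt() API added in this PR.
# Module path, key, and image file are placeholders, not part of the diff.
from rag.llm.cv_model import GptV4

cv_model = GptV4(key="YOUR_API_KEY", model_name="gpt-4-vision-preview")

with open("page_001.jpg", "rb") as f:
    image_bytes = f.read()  # image2base64() accepts raw bytes

# Without a prompt, the shared vision_llm_describe_prompt() template is used.
text, total_tokens = cv_model.describe_with_prompt(image_bytes)

# A caller-supplied prompt overrides the default template.
text, total_tokens = cv_model.describe_with_prompt(
    image_bytes, prompt="Transcribe this PDF page into Markdown."
)
```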

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Author: Yongteng Lei
Date: 2025-03-18 14:52:20 +08:00
Committed by: GitHub
Parent: 897fe85b5c
Commit: 5cf610af40
7 changed files with 413 additions and 102 deletions


@@ -13,31 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from openai.lib.azure import AzureOpenAI
-from zhipuai import ZhipuAI
-import io
-from abc import ABC
-from ollama import Client
-from PIL import Image
-from openai import OpenAI
-import os
import base64
-from io import BytesIO
+import io
import json
-import requests
+import os
+from abc import ABC
+from io import BytesIO
+import requests
+from ollama import Client
+from openai import OpenAI
+from openai.lib.azure import AzureOpenAI
+from PIL import Image
+from zhipuai import ZhipuAI
-from rag.nlp import is_english
+from api.utils import get_uuid
+from api.utils.file_utils import get_project_base_directory
+from rag.nlp import is_english
+from rag.prompts import vision_llm_describe_prompt
class Base(ABC):
def __init__(self, key, model_name):
pass
-def describe(self, image, max_tokens=300):
+def describe(self, image):
raise NotImplementedError("Please implement encode method!")
+def describe_with_prompt(self, image, prompt=None):
+raise NotImplementedError("Please implement encode method!")
def chat(self, system, history, gen_conf, image=""):
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -90,7 +95,7 @@ class Base(ABC):
yield ans + "\n**ERROR**: " + str(e)
yield tk_count
def image2base64(self, image):
if isinstance(image, bytes):
return base64.b64encode(image).decode("utf-8")
@@ -122,6 +127,25 @@
}
]
+def vision_llm_prompt(self, b64, prompt=None):
+return [
+{
+"role": "user",
+"content": [
+{
+"type": "image_url",
+"image_url": {
+"url": f"data:image/jpeg;base64,{b64}"
+},
+},
+{
+"type": "text",
+"text": prompt if prompt else vision_llm_describe_prompt(),
+},
+],
+}
+]
def chat_prompt(self, text, b64):
return [
{
@@ -140,12 +164,12 @@
class GptV4(Base):
def __init__(self, key, model_name="gpt-4-vision-preview", lang="Chinese", base_url="https://api.openai.com/v1"):
if not base_url:
-base_url="https://api.openai.com/v1"
+base_url = "https://api.openai.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
-def describe(self, image, max_tokens=300):
+def describe(self, image):
b64 = self.image2base64(image)
prompt = self.prompt(b64)
for i in range(len(prompt)):
@@ -159,6 +183,16 @@ class GptV4(Base):
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
+def describe_with_prompt(self, image, prompt=None):
+b64 = self.image2base64(image)
+vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+res = self.client.chat.completions.create(
+model=self.model_name,
+messages=vision_prompt,
+)
+return res.choices[0].message.content.strip(), res.usage.total_tokens
class AzureGptV4(Base):
def __init__(self, key, model_name, lang="Chinese", **kwargs):
@@ -168,7 +202,7 @@ class AzureGptV4(Base):
self.model_name = model_name
self.lang = lang
-def describe(self, image, max_tokens=300):
+def describe(self, image):
b64 = self.image2base64(image)
prompt = self.prompt(b64)
for i in range(len(prompt)):
@@ -182,6 +216,16 @@ class AzureGptV4(Base):
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
+def describe_with_prompt(self, image, prompt=None):
+b64 = self.image2base64(image)
+vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+res = self.client.chat.completions.create(
+model=self.model_name,
+messages=vision_prompt,
+)
+return res.choices[0].message.content.strip(), res.usage.total_tokens
class QWenCV(Base):
def __init__(self, key, model_name="qwen-vl-chat-v1", lang="Chinese", **kwargs):
@@ -212,23 +256,57 @@
}
]
+def vision_llm_prompt(self, binary, prompt=None):
+# stupid as hell
+tmp_dir = get_project_base_directory("tmp")
+if not os.path.exists(tmp_dir):
+os.mkdir(tmp_dir)
+path = os.path.join(tmp_dir, "%s.jpg" % get_uuid())
+Image.open(io.BytesIO(binary)).save(path)
+return [
+{
+"role": "user",
+"content": [
+{
+"image": f"file://{path}"
+},
+{
+"text": prompt if prompt else vision_llm_describe_prompt(),
+},
+],
+}
+]
def chat_prompt(self, text, b64):
return [
{"image": f"{b64}"},
{"text": text},
]
-def describe(self, image, max_tokens=300):
+def describe(self, image):
from http import HTTPStatus
from dashscope import MultiModalConversation
-response = MultiModalConversation.call(model=self.model_name,
-messages=self.prompt(image))
+response = MultiModalConversation.call(model=self.model_name, messages=self.prompt(image))
if response.status_code == HTTPStatus.OK:
return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
return response.message, 0
+def describe_with_prompt(self, image, prompt=None):
+from http import HTTPStatus
+from dashscope import MultiModalConversation
+vision_prompt = self.vision_llm_prompt(image, prompt) if prompt else self.vision_llm_prompt(image)
+response = MultiModalConversation.call(model=self.model_name, messages=vision_prompt)
+if response.status_code == HTTPStatus.OK:
+return response.output.choices[0]['message']['content'][0]["text"], response.usage.output_tokens
+return response.message, 0
def chat(self, system, history, gen_conf, image=""):
from http import HTTPStatus
from dashscope import MultiModalConversation
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -254,6 +332,7 @@ class QWenCV(Base):
def chat_streamly(self, system, history, gen_conf, image=""):
from http import HTTPStatus
from dashscope import MultiModalConversation
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -292,15 +371,25 @@ class Zhipu4V(Base):
self.model_name = model_name
self.lang = lang
-def describe(self, image, max_tokens=1024):
+def describe(self, image):
b64 = self.image2base64(image)
prompt = self.prompt(b64)
prompt[0]["content"][1]["type"] = "text"
res = self.client.chat.completions.create(
model=self.model_name,
-messages=prompt
+messages=prompt,
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
+def describe_with_prompt(self, image, prompt=None):
+b64 = self.image2base64(image)
+vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+res = self.client.chat.completions.create(
+model=self.model_name,
+messages=vision_prompt
+)
+return res.choices[0].message.content.strip(), res.usage.total_tokens
@@ -334,7 +423,7 @@ class Zhipu4V(Base):
his["content"] = self.chat_prompt(his["content"], image)
response = self.client.chat.completions.create(
-model=self.model_name,
+model=self.model_name,
messages=history,
temperature=gen_conf.get("temperature", 0.3),
top_p=gen_conf.get("top_p", 0.7),
@@ -364,7 +453,7 @@ class OllamaCV(Base):
self.model_name = model_name
self.lang = lang
-def describe(self, image, max_tokens=1024):
+def describe(self, image):
prompt = self.prompt("")
try:
response = self.client.generate(
@@ -377,6 +466,19 @@
except Exception as e:
return "**ERROR**: " + str(e), 0
+def describe_with_prompt(self, image, prompt=None):
+vision_prompt = self.vision_llm_prompt("", prompt) if prompt else self.vision_llm_prompt("")
+try:
+response = self.client.generate(
+model=self.model_name,
+prompt=vision_prompt[0]["content"][1]["text"],
+images=[image],
+)
+ans = response["response"].strip()
+return ans, 128
+except Exception as e:
+return "**ERROR**: " + str(e), 0
def chat(self, system, history, gen_conf, image=""):
if system:
history[-1]["content"] = system + history[-1]["content"] + "user query: " + history[-1]["content"]
@@ -460,7 +562,7 @@ class XinferenceCV(Base):
self.model_name = model_name
self.lang = lang
-def describe(self, image, max_tokens=300):
+def describe(self, image):
b64 = self.image2base64(image)
res = self.client.chat.completions.create(
@@ -469,27 +571,49 @@
)
return res.choices[0].message.content.strip(), res.usage.total_tokens
+def describe_with_prompt(self, image, prompt=None):
+b64 = self.image2base64(image)
+vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+res = self.client.chat.completions.create(
+model=self.model_name,
+messages=vision_prompt,
+)
+return res.choices[0].message.content.strip(), res.usage.total_tokens
class GeminiCV(Base):
def __init__(self, key, model_name="gemini-1.0-pro-vision-latest", lang="Chinese", **kwargs):
-from google.generativeai import client, GenerativeModel
+from google.generativeai import GenerativeModel, client
client.configure(api_key=key)
_client = client.get_default_generative_client()
self.model_name = model_name
self.model = GenerativeModel(model_name=self.model_name)
self.model._client = _client
-self.lang = lang
+self.lang = lang
-def describe(self, image, max_tokens=2048):
+def describe(self, image):
from PIL.Image import open
prompt = "请用中文详细描述一下图中的内容,比如时间,地点,人物,事情,人物心情等,如果有数据请提取出数据。" if self.lang.lower() == "chinese" else \
"Please describe the content of this picture, like where, when, who, what happen. If it has number data, please extract them out."
-b64 = self.image2base64(image)
-img = open(BytesIO(base64.b64decode(b64)))
-input = [prompt,img]
+b64 = self.image2base64(image)
+img = open(BytesIO(base64.b64decode(b64)))
+input = [prompt, img]
res = self.model.generate_content(
input
)
-return res.text,res.usage_metadata.total_token_count
+return res.text, res.usage_metadata.total_token_count
+def describe_with_prompt(self, image, prompt=None):
+from PIL.Image import open
+b64 = self.image2base64(image)
+vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+img = open(BytesIO(base64.b64decode(b64)))
+input = [vision_prompt, img]
+res = self.model.generate_content(
+input,
+)
+return res.text, res.usage_metadata.total_token_count
def chat(self, system, history, gen_conf, image=""):
from transformers import GenerationConfig
@@ -566,7 +690,7 @@ class LocalCV(Base):
def __init__(self, key, model_name="glm-4v", lang="Chinese", **kwargs):
pass
-def describe(self, image, max_tokens=1024):
+def describe(self, image):
return "", 0
@@ -590,7 +714,7 @@ class NvidiaCV(Base):
)
self.key = key
-def describe(self, image, max_tokens=1024):
+def describe(self, image):
b64 = self.image2base64(image)
response = requests.post(
url=self.base_url,
@@ -609,6 +733,27 @@
response["usage"]["total_tokens"],
)
+def describe_with_prompt(self, image, prompt=None):
+b64 = self.image2base64(image)
+vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+response = requests.post(
+url=self.base_url,
+headers={
+"accept": "application/json",
+"content-type": "application/json",
+"Authorization": f"Bearer {self.key}",
+},
+json={
+"messages": vision_prompt,
+},
+)
+response = response.json()
+return (
+response["choices"][0]["message"]["content"].strip(),
+response["usage"]["total_tokens"],
+)
def prompt(self, b64):
return [
{
@@ -622,6 +767,17 @@ class NvidiaCV(Base):
}
]
+def vision_llm_prompt(self, b64, prompt=None):
+return [
+{
+"role": "user",
+"content": (
+prompt if prompt else vision_llm_describe_prompt()
+)
++ f' <img src="data:image/jpeg;base64,{b64}"/>',
+}
+]
def chat_prompt(self, text, b64):
return [
{
@@ -634,7 +790,7 @@ class NvidiaCV(Base):
class StepFunCV(GptV4):
def __init__(self, key, model_name="step-1v-8k", lang="Chinese", base_url="https://api.stepfun.com/v1"):
if not base_url:
-base_url="https://api.stepfun.com/v1"
+base_url = "https://api.stepfun.com/v1"
self.client = OpenAI(api_key=key, base_url=base_url)
self.model_name = model_name
self.lang = lang
@@ -666,18 +822,18 @@ class TogetherAICV(GptV4):
def __init__(self, key, model_name, lang="Chinese", base_url="https://api.together.xyz/v1"):
if not base_url:
base_url = "https://api.together.xyz/v1"
-super().__init__(key, model_name,lang,base_url)
+super().__init__(key, model_name, lang, base_url)
class YiCV(GptV4):
-def __init__(self, key, model_name, lang="Chinese",base_url="https://api.lingyiwanwu.com/v1",):
+def __init__(self, key, model_name, lang="Chinese", base_url="https://api.lingyiwanwu.com/v1",):
if not base_url:
base_url = "https://api.lingyiwanwu.com/v1"
-super().__init__(key, model_name,lang,base_url)
+super().__init__(key, model_name, lang, base_url)
class HunyuanCV(Base):
-def __init__(self, key, model_name, lang="Chinese",base_url=None):
+def __init__(self, key, model_name, lang="Chinese", base_url=None):
from tencentcloud.common import credential
from tencentcloud.hunyuan.v20230901 import hunyuan_client
@@ -689,12 +845,12 @@ class HunyuanCV(Base):
self.client = hunyuan_client.HunyuanClient(cred, "")
self.lang = lang
-def describe(self, image, max_tokens=4096):
-from tencentcloud.hunyuan.v20230901 import models
+def describe(self, image):
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
+from tencentcloud.hunyuan.v20230901 import models
b64 = self.image2base64(image)
req = models.ChatCompletionsRequest()
params = {"Model": self.model_name, "Messages": self.prompt(b64)}
@@ -706,7 +862,24 @@
return ans, response.Usage.TotalTokens
except TencentCloudSDKException as e:
return ans + "\n**ERROR**: " + str(e), 0
+def describe_with_prompt(self, image, prompt=None):
+from tencentcloud.common.exception.tencent_cloud_sdk_exception import TencentCloudSDKException
+from tencentcloud.hunyuan.v20230901 import models
+b64 = self.image2base64(image)
+vision_prompt = self.vision_llm_prompt(b64, prompt) if prompt else self.vision_llm_prompt(b64)
+req = models.ChatCompletionsRequest()
+params = {"Model": self.model_name, "Messages": vision_prompt}
+req.from_json_string(json.dumps(params))
+ans = ""
+try:
+response = self.client.ChatCompletions(req)
+ans = response.Choices[0].Message.Content
+return ans, response.Usage.TotalTokens
+except TencentCloudSDKException as e:
+return ans + "\n**ERROR**: " + str(e), 0
def prompt(self, b64):
return [
{
@@ -725,4 +898,4 @@ class HunyuanCV(Base):
},
],
}
-]
+]
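
The PDF-parsing side of the feature lives in the other six changed files, which are not shown on this page. Purely as a hedged sketch of how a parser could drive the new API page by page (rasterization via `pdf2image` and the `vision_parse_pdf` helper are illustrative assumptions, not the PR's actual implementation):

```python
# Illustrative only: rasterize PDF pages and transcribe each one with a
# vision LLM through describe_with_prompt(); not the parser shipped here.
import io

from pdf2image import convert_from_path  # assumed rasterization library


def vision_parse_pdf(pdf_path, cv_model):
    pages = []
    for page_image in convert_from_path(pdf_path, dpi=200):
        buf = io.BytesIO()
        page_image.save(buf, format="JPEG")
        # With no prompt argument, the model falls back to the shared
        # vision_llm_describe_prompt() template.
        text, _tokens = cv_model.describe_with_prompt(buf.getvalue())
        pages.append(text)
    return "\n\n".join(pages)
```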