diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index b37028428..15c8943d8 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -14,6 +14,7 @@ # limitations under the License. # +import re import base64 import json import os @@ -32,7 +33,6 @@ from rag.nlp import is_english from rag.prompts.generator import vision_llm_describe_prompt from common.token_utils import num_tokens_from_string, total_token_count_from_response - class Base(ABC): def __init__(self, **kwargs): # Configure retry parameters @@ -208,6 +208,7 @@ class GptV4(Base): model=self.model_name, messages=self.prompt(b64), extra_body=self.extra_body, + unused = None, ) return res.choices[0].message.content.strip(), total_token_count_from_response(res) @@ -324,6 +325,122 @@ class Zhipu4V(GptV4): Base.__init__(self, **kwargs) + def _clean_conf(self, gen_conf): + if "max_tokens" in gen_conf: + del gen_conf["max_tokens"] + gen_conf = self._clean_conf_plealty(gen_conf) + return gen_conf + + + def _clean_conf_plealty(self, gen_conf): + if "presence_penalty" in gen_conf: + del gen_conf["presence_penalty"] + if "frequency_penalty" in gen_conf: + del gen_conf["frequency_penalty"] + return gen_conf + + + def _request(self, msg, stream, gen_conf={}): + response = requests.post( + self.base_url, + json={ + "model": self.model_name, + "messages": msg, + "stream": stream, + **gen_conf + }, + headers= { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + ) + return response.json() + + + def chat(self, system, history, gen_conf, images=None, stream=False, **kwargs): + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + + gen_conf = self._clean_conf(gen_conf) + + logging.info(json.dumps(history, ensure_ascii=False, indent=2)) + response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=False, **gen_conf) + content = response.choices[0].message.content.strip() + + cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip() + return cleaned, total_token_count_from_response(response) + + + def chat_streamly(self, system, history, gen_conf, images=None, **kwargs): + from rag.llm.chat_model import LENGTH_NOTIFICATION_CN, LENGTH_NOTIFICATION_EN + from rag.nlp import is_chinese + + if system and history and history[0].get("role") != "system": + history.insert(0, {"role": "system", "content": system}) + gen_conf = self._clean_conf(gen_conf) + ans = "" + tk_count = 0 + try: + logging.info(json.dumps(history, ensure_ascii=False, indent=2)) + response = self.client.chat.completions.create(model=self.model_name, messages=self._form_history(system, history, images), stream=True, **gen_conf) + for resp in response: + if not resp.choices[0].delta.content: + continue + delta = resp.choices[0].delta.content + ans = delta + if resp.choices[0].finish_reason == "length": + if is_chinese(ans): + ans += LENGTH_NOTIFICATION_CN + else: + ans += LENGTH_NOTIFICATION_EN + tk_count = total_token_count_from_response(resp) + if resp.choices[0].finish_reason == "stop": + tk_count = total_token_count_from_response(resp) + yield ans + except Exception as e: + yield ans + "\n**ERROR**: " + str(e) + + yield tk_count + + + def describe(self, image): + return self.describe_with_prompt(image) + + + def describe_with_prompt(self, image, prompt=None): + b64 = self.image2base64(image) + if prompt is None: + prompt = "Describe this image." + + # Chat messages + messages = [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { "url": b64 } + }, + { + "type": "text", + "text": prompt + } + ] + } + ] + + resp = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + stream=False + ) + + content = resp.choices[0].message.content.strip() + cleaned = re.sub(r"<\|(begin_of_box|end_of_box)\|>", "", content).strip() + + return cleaned, num_tokens_from_string(cleaned) + + class StepFunCV(GptV4): _FACTORY_NAME = "StepFun"