From df3890827d7419d729e8c5fce16db6c148a2896a Mon Sep 17 00:00:00 2001
From: Yongteng Lei
Date: Wed, 26 Mar 2025 19:33:14 +0800
Subject: [PATCH] Refa: change LLM chat output from full to delta (incremental)
 (#6534)

### What problem does this PR solve?

Change the streamed LLM chat output from the full accumulated answer to delta (incremental) chunks, so each streamed event carries only the newly generated text.

### Type of change

- [x] Refactoring
---
 api/apps/sdk/session.py        | 202 ++++++++-------
 api/db/services/llm_service.py |  13 +-
 rag/llm/chat_model.py          | 461 ++++++++++++---------------------
 3 files changed, 277 insertions(+), 399 deletions(-)

diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 57e1aebc0..e0bcfa735 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -13,31 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import re
 import json
+import re
 import time
-from api.db import LLMType
-from api.db.services.conversation_service import ConversationService, iframe_completion
-from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.canvas_service import completion as agent_completion
-from api.db.services.dialog_service import ask, chat
+from flask import Response, jsonify, request
+
 from agent.canvas import Canvas
-from api.db import StatusEnum
+from api.db import LLMType, StatusEnum
 from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
 from api.db.services.canvas_service import UserCanvasService
-from api.db.services.dialog_service import DialogService
-from api.db.services.knowledgebase_service import KnowledgebaseService
-from api.utils import get_uuid
-from api.utils.api_utils import get_error_data_result, validate_request
-from api.utils.api_utils import get_result, token_required
-from api.db.services.llm_service import LLMBundle
+from api.db.services.canvas_service import completion as agent_completion
+from api.db.services.conversation_service import ConversationService, iframe_completion
+from api.db.services.conversation_service import completion as rag_completion
+from api.db.services.dialog_service import DialogService, ask, chat
 from api.db.services.file_service import FileService
+from api.db.services.knowledgebase_service import KnowledgebaseService
+from api.db.services.llm_service import LLMBundle
+from api.utils import get_uuid
+from api.utils.api_utils import get_error_data_result, get_result, token_required, validate_request
 
-from flask import jsonify, request, Response
 
-@manager.route('/chats/<chat_id>/sessions', methods=['POST'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
 def create(tenant_id, chat_id):
     req = request.json
@@ -50,7 +48,7 @@ def create(tenant_id, chat_id):
         "dialog_id": req["dialog_id"],
         "name": req.get("name", "New session"),
         "message": [{"role": "assistant", "content": dia[0].prompt_config.get("prologue")}],
-        "user_id": req.get("user_id", "")
+        "user_id": req.get("user_id", ""),
     }
     if not conv.get("name"):
         return get_error_data_result(message="`name` can not be empty.")
@@ -59,20 +57,20 @@ def create(tenant_id, chat_id):
     if not e:
         return get_error_data_result(message="Fail to create a session!")
     conv = conv.to_dict()
-    conv['messages'] = conv.pop("message")
+    conv["messages"] = conv.pop("message")
    conv["chat_id"] = conv.pop("dialog_id")
     del conv["reference"]
     return get_result(data=conv)
 
 
-@manager.route('/agents/<agent_id>/sessions', methods=['POST'])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
 def create_agent_session(tenant_id, agent_id):
     req = request.json
     if not request.is_json:
         req = request.form
     files = request.files
-    user_id = request.args.get('user_id', '')
+    user_id = request.args.get("user_id", "")
 
     e, cvs = UserCanvasService.get_by_id(agent_id)
     if not e:
@@ -113,7 +111,7 @@
                     ele.pop("value")
             else:
                 if req is not None and req.get(ele["key"]):
-                    ele["value"] = req[ele['key']]
+                    ele["value"] = req[ele["key"]]
                 else:
                     if "value" in ele:
                         ele.pop("value")
@@ -121,20 +119,13 @@
     for ans in canvas.run(stream=False):
         pass
     cvs.dsl = json.loads(str(canvas))
-    conv = {
-        "id": get_uuid(),
-        "dialog_id": cvs.id,
-        "user_id": user_id,
-        "message": [{"role": "assistant", "content": canvas.get_prologue()}],
-        "source": "agent",
-        "dsl": cvs.dsl
-    }
+    conv = {"id": get_uuid(), "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
     API4ConversationService.save(**conv)
     conv["agent_id"] = conv.pop("dialog_id")
     return get_result(data=conv)
 
 
-@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PUT"])  # noqa: F821
 @token_required
 def update(tenant_id, chat_id, session_id):
     req = request.json
@@ -156,14 +147,14 @@ def update(tenant_id, chat_id, session_id):
     return get_result()
 
 
-@manager.route('/chats/<chat_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/chats/<chat_id>/completions", methods=["POST"])  # noqa: F821
 @token_required
 def chat_completion(tenant_id, chat_id):
     req = request.json
     if not req:
         req = {"question": ""}
     if not req.get("session_id"):
-        req["question"]=""
+        req["question"] = ""
     if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
         return get_error_data_result(f"You don't own the chat {chat_id}")
     if req.get("session_id"):
@@ -185,7 +176,7 @@ def chat_completion(tenant_id, chat_id):
         return get_result(data=answer)
 
 
-@manager.route('/chats_openai/<chat_id>/chat/completions', methods=['POST'])  # noqa: F821
+@manager.route("/chats_openai/<chat_id>/chat/completions", methods=["POST"])  # noqa: F821
 @validate_request("model", "messages")  # noqa: F821
 @token_required
 def chat_completion_openai_like(tenant_id, chat_id):
@@ -260,35 +251,60 @@ def chat_completion_openai_like(tenant_id, chat_id):
     def streamed_response_generator(chat_id, dia, msg):
         token_used = 0
         answer_cache = ""
+        reasoning_cache = ""
         response = {
             "id": f"chatcmpl-{chat_id}",
-            "choices": [
-                {
-                    "delta": {
-                        "content": "",
-                        "role": "assistant",
-                        "function_call": None,
-                        "tool_calls": None
-                    },
-                    "finish_reason": None,
-                    "index": 0,
-                    "logprobs": None
-                }
-            ],
+            "choices": [{"delta": {"content": "", "role": "assistant", "function_call": None, "tool_calls": None, "reasoning_content": ""}, "finish_reason": None, "index": 0, "logprobs": None}],
             "created": int(time.time()),
             "model": "model",
             "object": "chat.completion.chunk",
             "system_fingerprint": "",
-            "usage": None
+            "usage": None,
         }
 
         try:
             for ans in chat(dia, msg, True):
                 answer = ans["answer"]
-                incremental = answer.replace(answer_cache, "", 1)
-                answer_cache = answer.rstrip("</think>")
-                token_used += len(incremental)
-                response["choices"][0]["delta"]["content"] = incremental
+
+                reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
+                if reasoning_match:
+                    reasoning_part = reasoning_match.group(1)
+                    content_part = answer[reasoning_match.end() :]
+                else:
+                    reasoning_part = ""
+                    content_part = answer
+
+                reasoning_incremental = ""
+                if reasoning_part:
+                    if reasoning_part.startswith(reasoning_cache):
+                        reasoning_incremental = reasoning_part.replace(reasoning_cache, "", 1)
+                    else:
+                        reasoning_incremental = reasoning_part
+                    reasoning_cache = reasoning_part
+
+                content_incremental = ""
+                if content_part:
+                    if content_part.startswith(answer_cache):
+                        content_incremental = content_part.replace(answer_cache, "", 1)
+                    else:
+                        content_incremental = content_part
+                    answer_cache = content_part
+
+                token_used += len(reasoning_incremental) + len(content_incremental)
+
+                if not any([reasoning_incremental, content_incremental]):
+                    continue
+
+                if reasoning_incremental:
+                    response["choices"][0]["delta"]["reasoning_content"] = reasoning_incremental
+                else:
+                    response["choices"][0]["delta"]["reasoning_content"] = None
+
+                if content_incremental:
+                    response["choices"][0]["delta"]["content"] = content_incremental
+                else:
+                    response["choices"][0]["delta"]["content"] = None
+
                 yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
         except Exception as e:
             response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
@@ -296,16 +312,12 @@ def chat_completion_openai_like(tenant_id, chat_id):
 
         # The last chunk
         response["choices"][0]["delta"]["content"] = None
+        response["choices"][0]["delta"]["reasoning_content"] = None
         response["choices"][0]["finish_reason"] = "stop"
-        response["usage"] = {
-            "prompt_tokens": len(prompt),
-            "completion_tokens": token_used,
-            "total_tokens": len(prompt) + token_used
-        }
+        response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
         yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
         yield "data:[DONE]\n\n"
 
-
     resp = Response(streamed_response_generator(chat_id, dia, msg), mimetype="text/event-stream")
     resp.headers.add_header("Cache-control", "no-cache")
     resp.headers.add_header("Connection", "keep-alive")
@@ -320,7 +332,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
                 break
             content = answer["answer"]
 
-        response = { 
+        response = {
             "id": f"chatcmpl-{chat_id}",
             "object": "chat.completion",
             "created": int(time.time()),
@@ -332,25 +344,15 @@ def chat_completion_openai_like(tenant_id, chat_id):
                 "completion_tokens_details": {
                     "reasoning_tokens": context_token_used,
                     "accepted_prediction_tokens": len(content),
-                    "rejected_prediction_tokens": 0  # 0 for simplicity
-                }
+                    "rejected_prediction_tokens": 0,  # 0 for simplicity
+                },
             },
-            "choices": [
-                {
-                    "message": {
-                        "role": "assistant",
-                        "content": content
-                    },
-                    "logprobs": None,
-                    "finish_reason": "stop",
-                    "index": 0
-                }
-            ]
+            "choices": [{"message": {"role": "assistant", "content": content}, "logprobs": None, "finish_reason": "stop", "index": 0}],
        }
         return jsonify(response)
 
 
-@manager.route('/agents/<agent_id>/completions', methods=['POST'])  # noqa: F821
+@manager.route("/agents/<agent_id>/completions", methods=["POST"])  # noqa: F821
 @token_required
 def agent_completions(tenant_id, agent_id):
     req = request.json
@@ -361,8 +363,8 @@ def agent_completions(tenant_id, agent_id):
         dsl = cvs[0].dsl
         if not isinstance(dsl, str):
             dsl = json.dumps(dsl)
-        #canvas = Canvas(dsl, tenant_id)
-        #if canvas.get_preset_param():
+        # canvas = Canvas(dsl, tenant_id)
+        # if canvas.get_preset_param():
         #     req["question"] = ""
         conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
         if not conv:
@@ -376,9 +378,7 @@ def agent_completions(tenant_id, agent_id):
             states = {field: current_dsl.get(field, []) for field in state_fields}
             current_dsl.update(new_dsl)
             current_dsl.update(states)
-            API4ConversationService.update_by_id(req["session_id"], {
-                "dsl": current_dsl
-            })
+            API4ConversationService.update_by_id(req["session_id"], {"dsl": current_dsl})
     else:
         req["question"] = ""
     if req.get("stream", True):
@@ -395,7 +395,7 @@ def agent_completions(tenant_id, agent_id):
         return get_error_data_result(str(e))
 
 
-@manager.route('/chats/<chat_id>/sessions', methods=['GET'])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["GET"])  # noqa: F821
 @token_required
 def list_session(tenant_id, chat_id):
     if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
@@ -414,7 +414,7 @@ def list_session(tenant_id, chat_id):
     if not convs:
         return get_result(data=[])
     for conv in convs:
-        conv['messages'] = conv.pop("message")
+        conv["messages"] = conv.pop("message")
         infos = conv["messages"]
         for info in infos:
             if "prompt" in info:
@@ -448,7 +448,7 @@ def list_session(tenant_id, chat_id):
     return get_result(data=convs)
 
 
-@manager.route('/agents/<agent_id>/sessions', methods=['GET'])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["GET"])  # noqa: F821
 @token_required
 def list_agent_session(tenant_id, agent_id):
     if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
@@ -464,12 +464,11 @@ def list_agent_session(tenant_id, agent_id):
         desc = True
     # dsl defaults to True in all cases except for False and false
     include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
-    convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
-                                             user_id, include_dsl)
+    convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
     if not convs:
         return get_result(data=[])
     for conv in convs:
-        conv['messages'] = conv.pop("message")
+        conv["messages"] = conv.pop("message")
         infos = conv["messages"]
         for info in infos:
             if "prompt" in info:
@@ -502,7 +501,7 @@ def list_agent_session(tenant_id, agent_id):
     return get_result(data=convs)
 
 
-@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])  # noqa: F821
+@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"])  # noqa: F821
 @token_required
 def delete(tenant_id, chat_id):
     if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@@ -528,14 +527,14 @@ def delete(tenant_id, chat_id):
     return get_result()
 
 
-@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"])  # noqa: F821
+@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"])  # noqa: F821
 @token_required
 def delete_agent_session(tenant_id, agent_id):
     req = request.json
     cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
     if not cvs:
         return get_error_data_result(f"You don't own the agent {agent_id}")
-    
+
     convs = API4ConversationService.query(dialog_id=agent_id)
     if not convs:
         return get_error_data_result(f"Agent {agent_id} has no sessions")
@@ -551,16 +550,16 @@ def delete_agent_session(tenant_id, agent_id):
             conv_list.append(conv.id)
     else:
         conv_list = ids
-    
+
     for session_id in conv_list:
         conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
         if not conv:
             return get_error_data_result(f"The agent doesn't own the session ${session_id}")
         API4ConversationService.delete_by_id(session_id)
     return get_result()
-    
 
-@manager.route('/sessions/ask', methods=['POST'])  # noqa: F821
+
+@manager.route("/sessions/ask", methods=["POST"])  # noqa: F821
 @token_required
 def ask_about(tenant_id):
     req = request.json
@@ -586,9 +585,7 @@ def ask_about(tenant_id):
             for ans in ask(req["question"], req["kb_ids"], uid):
req["kb_ids"], uid): yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" except Exception as e: - yield "data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, - ensure_ascii=False) + "\n\n" + yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n" resp = Response(stream(), mimetype="text/event-stream") @@ -599,7 +596,7 @@ def ask_about(tenant_id): return resp -@manager.route('/sessions/related_questions', methods=['POST']) # noqa: F821 +@manager.route("/sessions/related_questions", methods=["POST"]) # noqa: F821 @token_required def related_questions(tenant_id): req = request.json @@ -631,18 +628,27 @@ Reason: - At the same time, related terms can also help search engines better understand user needs and return more accurate search results. """ - ans = chat_mdl.chat(prompt, [{"role": "user", "content": f""" + ans = chat_mdl.chat( + prompt, + [ + { + "role": "user", + "content": f""" Keywords: {question} Related search terms: - """}], {"temperature": 0.9}) + """, + } + ], + {"temperature": 0.9}, + ) return get_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)]) -@manager.route('/chatbots//completions', methods=['POST']) # noqa: F821 +@manager.route("/chatbots//completions", methods=["POST"]) # noqa: F821 def chatbot_completions(dialog_id): req = request.json - token = request.headers.get('Authorization').split() + token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') token = token[1] @@ -665,11 +671,11 @@ def chatbot_completions(dialog_id): return get_result(data=answer) -@manager.route('/agentbots//completions', methods=['POST']) # noqa: F821 +@manager.route("/agentbots//completions", methods=["POST"]) # noqa: F821 def agent_bot_completions(agent_id): req = request.json - token = request.headers.get('Authorization').split() + token = request.headers.get("Authorization").split() if len(token) != 2: return get_error_data_result(message='Authorization is not valid!"') token = token[1] diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index f5abb776f..d86e0bf0b 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -324,15 +324,18 @@ class LLMBundle: if self.langfuse: generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history}) - output = "" + ans = "" for txt in self.mdl.chat_streamly(system, history, gen_conf): if isinstance(txt, int): if self.langfuse: - generation.end(output={"output": output}) + generation.end(output={"output": ans}) if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name): logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt)) - return + return ans - output = txt - yield txt + if txt.endswith(""): + ans = ans.rstrip("") + + ans += txt + yield ans diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 6c2f15499..0ed660430 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -1,5 +1,5 @@ # -# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. 
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
 #
 #  Licensed under the Apache License, Version 2.0 (the "License");
 #  you may not use this file except in compliance with the License.
@@ -13,25 +13,25 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #
-import re
+import asyncio
+import json
+import logging
+import os
 import random
+import re
+import time
+from abc import ABC
 
+import openai
+import requests
+from dashscope import Generation
+from ollama import Client
+from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
-from dashscope import Generation
-from abc import ABC
-from openai import OpenAI
-import openai
-from ollama import Client
+
 from rag.nlp import is_chinese, is_english
 from rag.utils import num_tokens_from_string
-import os
-import json
-import requests
-import asyncio
-import logging
-import time
-
 
 # Error message constants
 ERROR_PREFIX = "**ERROR**"
@@ -53,21 +53,21 @@ LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to
 
 class Base(ABC):
     def __init__(self, key, model_name, base_url):
-        timeout = int(os.environ.get('LM_TIMEOUT_SECONDS', 600))
+        timeout = int(os.environ.get("LM_TIMEOUT_SECONDS", 600))
         self.client = OpenAI(api_key=key, base_url=base_url, timeout=timeout)
         self.model_name = model_name
         # Configure retry parameters
-        self.max_retries = int(os.environ.get('LLM_MAX_RETRIES', 5))
-        self.base_delay = float(os.environ.get('LLM_BASE_DELAY', 2.0))
+        self.max_retries = int(os.environ.get("LLM_MAX_RETRIES", 5))
+        self.base_delay = float(os.environ.get("LLM_BASE_DELAY", 2.0))
 
     def _get_delay(self, attempt):
         """Calculate retry delay time"""
-        return self.base_delay * (2 ** attempt) + random.uniform(0, 0.5)
-    
+        return self.base_delay * (2**attempt) + random.uniform(0, 0.5)
+
     def _classify_error(self, error):
         """Classify error based on error message content"""
         error_str = str(error).lower()
-        
+
         if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str:
             return ERROR_RATE_LIMIT
         elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str:
             return ERROR_AUTHENTICATION
@@ -98,11 +98,8 @@ class Base(ABC):
         # Implement exponential backoff retry strategy
         for attempt in range(self.max_retries):
             try:
-                response = self.client.chat.completions.create(
-                    model=self.model_name,
-                    messages=history,
-                    **gen_conf)
-
+                response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
+
                 if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
                     return "", 0
                 ans = response.choices[0].message.content.strip()
@@ -111,17 +108,17 @@ class Base(ABC):
                         ans += LENGTH_NOTIFICATION_CN
                     else:
                         ans += LENGTH_NOTIFICATION_EN
-                return ans, self.total_token_count(response) 
+                return ans, self.total_token_count(response)
             except Exception as e:
                 # Classify the error
                 error_code = self._classify_error(e)
-                
+
                 # Check if it's a rate limit error or server error and not the last attempt
                 should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries - 1
-                
+
                 if should_retry:
                     delay = self._get_delay(attempt)
-                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt+1}/{self.max_retries})")
+                    logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
                     time.sleep(delay)
                 else:
                     # For non-rate limit errors or the last attempt, return an error message
@@ -136,24 +133,23 @@ class Base(ABC):
             del gen_conf["max_tokens"]
         ans = ""
         total_tokens = 0
+        reasoning_start = False
         try:
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=history,
-                stream=True,
-                **gen_conf)
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
             for resp in response:
                 if not resp.choices:
                     continue
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
-                    if ans.find("<think>") < 0:
-                        ans += "<think>"
-                    ans = ans.replace("</think>", "")
+                    ans = ""
+                    if not reasoning_start:
+                        reasoning_start = True
+                        ans = "<think>"
+                    ans += resp.choices[0].delta.reasoning_content + "</think>"
                 else:
-                    ans += resp.choices[0].delta.content
+                    reasoning_start = False
+                    ans = resp.choices[0].delta.content
 
                 tol = self.total_token_count(resp)
                 if not tol:
@@ -221,7 +217,7 @@ class ModelScopeChat(Base):
     def __init__(self, key=None, model_name="", base_url=""):
         if not base_url:
             raise ValueError("Local llm url cannot be None")
-        base_url = base_url.rstrip('/')
+        base_url = base_url.rstrip("/")
         if base_url.split("/")[-1] != "v1":
             base_url = os.path.join(base_url, "v1")
         super().__init__(key, model_name.split("___")[0], base_url)
@@ -236,8 +232,8 @@ class DeepSeekChat(Base):
 
 class AzureChat(Base):
     def __init__(self, key, model_name, **kwargs):
-        api_key = json.loads(key).get('api_key', '')
-        api_version = json.loads(key).get('api_version', '2024-02-01')
+        api_key = json.loads(key).get("api_key", "")
+        api_version = json.loads(key).get("api_version", "2024-02-01")
         self.client = AzureOpenAI(api_key=api_key, azure_endpoint=kwargs["base_url"], api_version=api_version)
         self.model_name = model_name
 
@@ -264,16 +260,9 @@ class BaiChuanChat(Base):
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=history,
-                extra_body={
-                    "tools": [{
-                        "type": "web_search",
-                        "web_search": {
-                            "enable": True,
-                            "search_mode": "performance_first"
-                        }
-                    }]
-                },
-                **self._format_params(gen_conf))
+                extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
+                **self._format_params(gen_conf),
+            )
             ans = response.choices[0].message.content.strip()
             if response.choices[0].finish_reason == "length":
                 if is_chinese([ans]):
@@ -295,23 +284,16 @@ class BaiChuanChat(Base):
             response = self.client.chat.completions.create(
                 model=self.model_name,
                 messages=history,
-                extra_body={
-                    "tools": [{
-                        "type": "web_search",
-                        "web_search": {
-                            "enable": True,
-                            "search_mode": "performance_first"
-                        }
-                    }]
-                },
+                extra_body={"tools": [{"type": "web_search", "web_search": {"enable": True, "search_mode": "performance_first"}}]},
                 stream=True,
-                **self._format_params(gen_conf))
+                **self._format_params(gen_conf),
+            )
             for resp in response:
                 if not resp.choices:
                     continue
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
-                ans += resp.choices[0].delta.content
+                ans = resp.choices[0].delta.content
                 tol = self.total_token_count(resp)
                 if not tol:
                     total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
@@ -333,6 +315,7 @@ class BaiChuanChat(Base):
 class QWenChat(Base):
     def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
         import dashscope
+
         dashscope.api_key = key
         self.model_name = model_name
         if self.is_reasoning_model(self.model_name):
@@ -344,22 +327,18 @@ class QWenChat(Base):
         if self.is_reasoning_model(self.model_name):
             return super().chat(system, history, gen_conf)
 
-        stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
+        stream_flag = str(os.environ.get("QWEN_CHAT_BY_STREAM", "true")).lower() == "true"
         if not stream_flag:
             from http import HTTPStatus
+
             if system:
                 history.insert(0, {"role": "system", "content": system})
 
-            response = Generation.call(
-                self.model_name,
-                messages=history,
-                result_format='message',
-                **gen_conf
-            )
+            response = Generation.call(self.model_name, messages=history, result_format="message", **gen_conf)
             ans = ""
             tk_count = 0
             if response.status_code == HTTPStatus.OK:
-                ans += response.output.choices[0]['message']['content']
+                ans += response.output.choices[0]["message"]["content"]
                 tk_count += self.total_token_count(response)
                 if response.output.choices[0].get("finish_reason", "") == "length":
                     if is_chinese([ans]):
@@ -378,8 +357,9 @@ class QWenChat(Base):
         else:
             return "".join(result_list[:-1]), result_list[-1]
 
-    def _chat_streamly(self, system, history, gen_conf, incremental_output=False):
+    def _chat_streamly(self, system, history, gen_conf, incremental_output=True):
         from http import HTTPStatus
+
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
@@ -387,17 +367,10 @@ class QWenChat(Base):
         ans = ""
         tk_count = 0
         try:
-            response = Generation.call(
-                self.model_name,
-                messages=history,
-                result_format='message',
-                stream=True,
-                incremental_output=incremental_output,
-                **gen_conf
-            )
+            response = Generation.call(self.model_name, messages=history, result_format="message", stream=True, incremental_output=incremental_output, **gen_conf)
             for resp in response:
                 if resp.status_code == HTTPStatus.OK:
-                    ans = resp.output.choices[0]['message']['content']
+                    ans = resp.output.choices[0]["message"]["content"]
                     tk_count = self.total_token_count(resp)
                     if resp.output.choices[0].get("finish_reason", "") == "length":
                         if is_chinese(ans):
@@ -406,8 +379,11 @@ class QWenChat(Base):
                             ans += LENGTH_NOTIFICATION_EN
                     yield ans
                 else:
-                    yield ans + "\n**ERROR**: " + resp.message if not re.search(r" (key|quota)",
-                                                                                str(resp.message).lower()) else "Out of credit. Please set the API key in **settings > Model providers.**"
+                    yield (
+                        ans + "\n**ERROR**: " + resp.message
+                        if not re.search(r" (key|quota)", str(resp.message).lower())
+                        else "Out of credit. Please set the API key in **settings > Model providers.**"
+                    )
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
@@ -423,10 +399,12 @@ class QWenChat(Base):
 
     @staticmethod
     def is_reasoning_model(model_name: str) -> bool:
-        return any([
-            model_name.lower().find("deepseek") >= 0,
-            model_name.lower().find("qwq") >= 0 and model_name.lower() != 'qwq-32b-preview',
-        ])
+        return any(
+            [
+                model_name.lower().find("deepseek") >= 0,
+                model_name.lower().find("qwq") >= 0 and model_name.lower() != "qwq-32b-preview",
+            ]
+        )
 
 
 class ZhipuChat(Base):
@@ -444,11 +422,7 @@ class ZhipuChat(Base):
                 del gen_conf["presence_penalty"]
             if "frequency_penalty" in gen_conf:
                 del gen_conf["frequency_penalty"]
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=history,
-                **gen_conf
-            )
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
             ans = response.choices[0].message.content.strip()
             if response.choices[0].finish_reason == "length":
                 if is_chinese(ans):
@@ -471,17 +445,12 @@ class ZhipuChat(Base):
         ans = ""
         tk_count = 0
         try:
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=history,
-                stream=True,
-                **gen_conf
-            )
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
             for resp in response:
                 if not resp.choices[0].delta.content:
                     continue
                 delta = resp.choices[0].delta.content
-                ans += delta
+                ans = delta
                 if resp.choices[0].finish_reason == "length":
                     if is_chinese(ans):
                         ans += LENGTH_NOTIFICATION_CN
@@ -499,8 +468,7 @@ class ZhipuChat(Base):
 
 class OllamaChat(Base):
     def __init__(self, key, model_name, **kwargs):
-        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else \
-            Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
+        self.client = Client(host=kwargs["base_url"]) if not key or key == "x" else Client(host=kwargs["base_url"], headers={"Authorization": f"Bear {key}"})
         self.model_name = model_name
 
     def chat(self, system, history, gen_conf):
@@ -509,9 +477,7 @@ class OllamaChat(Base):
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
         try:
-            options = {
-                "num_ctx": 32768
-            }
+            options = {"num_ctx": 32768}
             if "temperature" in gen_conf:
                 options["temperature"] = gen_conf["temperature"]
             if "max_tokens" in gen_conf:
@@ -522,12 +488,7 @@ class OllamaChat(Base):
                 options["presence_penalty"] = gen_conf["presence_penalty"]
             if "frequency_penalty" in gen_conf:
                 options["frequency_penalty"] = gen_conf["frequency_penalty"]
-            response = self.client.chat(
-                model=self.model_name,
-                messages=history,
-                options=options,
-                keep_alive=-1
-            )
+            response = self.client.chat(model=self.model_name, messages=history, options=options, keep_alive=-1)
             ans = response["message"]["content"].strip()
             return ans, response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
         except Exception as e:
@@ -551,17 +512,11 @@ class OllamaChat(Base):
                 options["frequency_penalty"] = gen_conf["frequency_penalty"]
         ans = ""
         try:
-            response = self.client.chat(
-                model=self.model_name,
-                messages=history,
-                stream=True,
-                options=options,
-                keep_alive=-1
-            )
+            response = self.client.chat(model=self.model_name, messages=history, stream=True, options=options, keep_alive=-1)
             for resp in response:
                 if resp["done"]:
                     yield resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
-                ans += resp["message"]["content"]
+                ans = resp["message"]["content"]
                 yield ans
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
@@ -588,9 +543,7 @@ class LocalLLM(Base):
     def __conn(self):
         from multiprocessing.connection import Client
 
-        self._connection = Client(
-            (self.host, self.port), authkey=b"infiniflow-token4kevinhu"
-        )
+        self._connection = Client((self.host, self.port), authkey=b"infiniflow-token4kevinhu")
 
     def __getattr__(self, name):
         import pickle
@@ -613,17 +566,17 @@ class LocalLLM(Base):
 
     def _prepare_prompt(self, system, history, gen_conf):
         from rag.svr.jina_server import Prompt
+
         if system:
             history.insert(0, {"role": "system", "content": system})
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
         from rag.svr.jina_server import Generation
+
         answer = ""
         try:
-            res = self.client.stream_doc(
-                on=endpoint, inputs=prompt, return_type=Generation
-            )
+            res = self.client.stream_doc(on=endpoint, inputs=prompt, return_type=Generation)
             loop = asyncio.get_event_loop()
             try:
                 while True:
@@ -652,24 +605,24 @@ class LocalLLM(Base):
 
 
 class VolcEngineChat(Base):
-    def __init__(self, key, model_name, base_url='https://ark.cn-beijing.volces.com/api/v3'):
+    def __init__(self, key, model_name, base_url="https://ark.cn-beijing.volces.com/api/v3"):
         """
         Since do not want to modify the original database fields, and the VolcEngine authentication method is quite special,
         Assemble ark_api_key, ep_id into api_key, store it as a dictionary type, and parse it for use
         model_name is for display only
         """
-        base_url = base_url if base_url else 'https://ark.cn-beijing.volces.com/api/v3'
-        ark_api_key = json.loads(key).get('ark_api_key', '')
-        model_name = json.loads(key).get('ep_id', '') + json.loads(key).get('endpoint_id', '')
+        base_url = base_url if base_url else "https://ark.cn-beijing.volces.com/api/v3"
+        ark_api_key = json.loads(key).get("ark_api_key", "")
+        model_name = json.loads(key).get("ep_id", "") + json.loads(key).get("endpoint_id", "")
         super().__init__(ark_api_key, model_name, base_url)
 
 
 class MiniMaxChat(Base):
     def __init__(
-            self,
-            key,
-            model_name,
-            base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+        self,
+        key,
+        model_name,
+        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
     ):
         if not base_url:
             base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@@ -687,13 +640,9 @@ class MiniMaxChat(Base):
             "Authorization": f"Bearer {self.api_key}",
             "Content-Type": "application/json",
         }
-        payload = json.dumps(
-            {"model": self.model_name, "messages": history, **gen_conf}
-        )
+        payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf})
         try:
-            response = requests.request(
-                "POST", url=self.base_url, headers=headers, data=payload
-            )
+            response = requests.request("POST", url=self.base_url, headers=headers, data=payload)
             response = response.json()
             ans = response["choices"][0]["message"]["content"].strip()
             if response["choices"][0]["finish_reason"] == "length":
@@ -737,7 +686,7 @@ class MiniMaxChat(Base):
                     text = ""
                     if "choices" in resp and "delta" in resp["choices"][0]:
                         text = resp["choices"][0]["delta"]["content"]
-                    ans += text
+                    ans = text
                     tol = self.total_token_count(resp)
                     if not tol:
                         total_tokens += num_tokens_from_string(text)
@@ -752,9 +701,9 @@ class MiniMaxChat(Base):
 
 
 class MistralChat(Base):
-
     def __init__(self, key, model_name, base_url=None):
         from mistralai.client import MistralClient
+
         self.client = MistralClient(api_key=key)
         self.model_name = model_name
 
@@ -765,10 +714,7 @@ class MistralChat(Base):
             if k not in ["temperature", "top_p", "max_tokens"]:
                 del gen_conf[k]
         try:
-            response = self.client.chat(
-                model=self.model_name,
-                messages=history,
-                **gen_conf)
+            response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
             ans = response.choices[0].message.content
             if response.choices[0].finish_reason == "length":
                 if is_chinese(ans):
@@ -788,14 +734,11 @@ class MistralChat(Base):
         ans = ""
         total_tokens = 0
         try:
-            response = self.client.chat_stream(
-                model=self.model_name,
-                messages=history,
-                **gen_conf)
+            response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf)
             for resp in response:
                 if not resp.choices or not resp.choices[0].delta.content:
                     continue
-                ans += resp.choices[0].delta.content
+                ans = resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
                     if is_chinese(ans):
@@ -811,23 +754,23 @@ class MistralChat(Base):
 
 
 class BedrockChat(Base):
-
     def __init__(self, key, model_name, **kwargs):
         import boto3
-        self.bedrock_ak = json.loads(key).get('bedrock_ak', '')
-        self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
-        self.bedrock_region = json.loads(key).get('bedrock_region', '')
+
+        self.bedrock_ak = json.loads(key).get("bedrock_ak", "")
+        self.bedrock_sk = json.loads(key).get("bedrock_sk", "")
+        self.bedrock_region = json.loads(key).get("bedrock_region", "")
         self.model_name = model_name
 
-        if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
+        if self.bedrock_ak == "" or self.bedrock_sk == "" or self.bedrock_region == "":
             # Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
-            self.client = boto3.client('bedrock-runtime')
+            self.client = boto3.client("bedrock-runtime")
         else:
-            self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
-                                       aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
+            self.client = boto3.client(service_name="bedrock-runtime", region_name=self.bedrock_region, aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)
 
     def chat(self, system, history, gen_conf):
         from botocore.exceptions import ClientError
+
         for k in list(gen_conf.keys()):
             if k not in ["temperature"]:
                 del gen_conf[k]
@@ -853,6 +796,7 @@ class BedrockChat(Base):
 
     def chat_streamly(self, system, history, gen_conf):
         from botocore.exceptions import ClientError
+
         for k in list(gen_conf.keys()):
             if k not in ["temperature"]:
                 del gen_conf[k]
@@ -860,14 +804,9 @@ class BedrockChat(Base):
             if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
                 item["content"] = [{"text": item["content"]}]
 
-        if self.model_name.split('.')[0] == 'ai21':
+        if self.model_name.split(".")[0] == "ai21":
             try:
-                response = self.client.converse(
-                    modelId=self.model_name,
-                    messages=history,
-                    inferenceConfig=gen_conf,
-                    system=[{"text": (system if system else "Answer the user's message.")}]
-                )
+                response = self.client.converse(modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}])
                 ans = response["output"]["message"]["content"][0]["text"]
                 return ans, num_tokens_from_string(ans)
 
@@ -878,16 +817,13 @@ class BedrockChat(Base):
 
         try:
             # Send the message to the model, using a basic inference configuration.
             streaming_response = self.client.converse_stream(
-                modelId=self.model_name,
-                messages=history,
-                inferenceConfig=gen_conf,
-                system=[{"text": (system if system else "Answer the user's message.")}]
+                modelId=self.model_name, messages=history, inferenceConfig=gen_conf, system=[{"text": (system if system else "Answer the user's message.")}]
             )
 
             # Extract and print the streamed response text in real-time.
             for resp in streaming_response["stream"]:
                 if "contentBlockDelta" in resp:
-                    ans += resp["contentBlockDelta"]["delta"]["text"]
+                    ans = resp["contentBlockDelta"]["delta"]["text"]
                     yield ans
 
         except (ClientError, Exception) as e:
@@ -897,13 +833,12 @@ class BedrockChat(Base):
 
 
 class GeminiChat(Base):
-
     def __init__(self, key, model_name, base_url=None):
-        from google.generativeai import client, GenerativeModel
+        from google.generativeai import GenerativeModel, client
 
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
-        self.model_name = 'models/' + model_name
+        self.model_name = "models/" + model_name
         self.model = GenerativeModel(model_name=self.model_name)
         self.model._client = _client
 
@@ -916,17 +851,15 @@ class GeminiChat(Base):
             if k not in ["temperature", "top_p", "max_tokens"]:
                 del gen_conf[k]
         for item in history:
-            if 'role' in item and item['role'] == 'assistant':
-                item['role'] = 'model'
-            if 'role' in item and item['role'] == 'system':
-                item['role'] = 'user'
-            if 'content' in item:
-                item['parts'] = item.pop('content')
+            if "role" in item and item["role"] == "assistant":
+                item["role"] = "model"
+            if "role" in item and item["role"] == "system":
+                item["role"] = "user"
+            if "content" in item:
+                item["parts"] = item.pop("content")
 
         try:
-            response = self.model.generate_content(
-                history,
-                generation_config=gen_conf)
+            response = self.model.generate_content(history, generation_config=gen_conf)
             ans = response.text
             return ans, response.usage_metadata.total_token_count
         except Exception as e:
@@ -941,17 +874,15 @@ class GeminiChat(Base):
             if k not in ["temperature", "top_p", "max_tokens"]:
                 del gen_conf[k]
         for item in history:
-            if 'role' in item and item['role'] == 'assistant':
-                item['role'] = 'model'
-            if 'content' in item:
-                item['parts'] = item.pop('content')
+            if "role" in item and item["role"] == "assistant":
+                item["role"] = "model"
+            if "content" in item:
+                item["parts"] = item.pop("content")
         ans = ""
         try:
-            response = self.model.generate_content(
-                history,
-                generation_config=gen_conf, stream=True)
+            response = self.model.generate_content(history, generation_config=gen_conf, stream=True)
             for resp in response:
-                ans += resp.text
+                ans = resp.text
                 yield ans
 
             yield response._chunks[-1].usage_metadata.total_token_count
@@ -962,8 +893,9 @@ class GeminiChat(Base):
 
 
 class GroqChat(Base):
-    def __init__(self, key, model_name, base_url=''):
+    def __init__(self, key, model_name, base_url=""):
         from groq import Groq
+
         self.client = Groq(api_key=key)
         self.model_name = model_name
 
@@ -975,11 +907,7 @@ class GroqChat(Base):
                 del gen_conf[k]
         ans = ""
         try:
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=history,
-                **gen_conf
-            )
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
             ans = response.choices[0].message.content
             if response.choices[0].finish_reason == "length":
                 if is_chinese(ans):
@@ -999,16 +927,11 @@ class GroqChat(Base):
         ans = ""
         total_tokens = 0
         try:
-            response = self.client.chat.completions.create(
-                model=self.model_name,
-                messages=history,
-                stream=True,
-                **gen_conf
-            )
+            response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
             for resp in response:
                 if not resp.choices or not resp.choices[0].delta.content:
                     continue
-                ans += resp.choices[0].delta.content
+                ans = resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
                     if is_chinese(ans):
@@ -1096,16 +1019,10 @@ class CoHereChat(Base):
         mes = history.pop()["message"]
         ans = ""
         try:
-            response = self.client.chat(
-                model=self.model_name, chat_history=history, message=mes, **gen_conf
-            )
+            response = self.client.chat(model=self.model_name, chat_history=history, message=mes, **gen_conf)
             ans = response.text
             if response.finish_reason == "MAX_TOKENS":
-                ans += (
-                    "...\nFor the content length reason, it stopped, continue?"
-                    if is_english([ans])
-                    else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                )
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
             return (
                 ans,
                 response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
@@ -1133,20 +1050,14 @@ class CoHereChat(Base):
         ans = ""
         total_tokens = 0
         try:
-            response = self.client.chat_stream(
-                model=self.model_name, chat_history=history, message=mes, **gen_conf
-            )
+            response = self.client.chat_stream(model=self.model_name, chat_history=history, message=mes, **gen_conf)
             for resp in response:
                 if resp.event_type == "text-generation":
-                    ans += resp.text
+                    ans = resp.text
                    total_tokens += num_tokens_from_string(resp.text)
                 elif resp.event_type == "stream-end":
                     if resp.finish_reason == "MAX_TOKENS":
-                        ans += (
-                            "...\nFor the content length reason, it stopped, continue?"
-                            if is_english([ans])
-                            else "······\n由于长度的原因,回答被截断了,要继续吗?"
-                        )
+                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                 yield ans
 
         except Exception as e:
@@ -1217,9 +1128,7 @@ class ReplicateChat(Base):
             del gen_conf["max_tokens"]
         if system:
             self.system = system
-        prompt = "\n".join(
-            [item["role"] + ":" + item["content"] for item in history[-5:]]
-        )
+        prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
         ans = ""
         try:
             response = self.client.run(
@@ -1236,9 +1145,7 @@ class ReplicateChat(Base):
             del gen_conf["max_tokens"]
         if system:
             self.system = system
-        prompt = "\n".join(
-            [item["role"] + ":" + item["content"] for item in history[-5:]]
-        )
+        prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
         ans = ""
         try:
             response = self.client.run(
@@ -1246,7 +1153,7 @@ class ReplicateChat(Base):
                 input={"system_prompt": self.system, "prompt": prompt, **gen_conf},
             )
             for resp in response:
-                ans += resp
+                ans = resp
                 yield ans
 
         except Exception as e:
@@ -1268,10 +1175,10 @@ class HunyuanChat(Base):
         self.client = hunyuan_client.HunyuanClient(cred, "")
 
     def chat(self, system, history, gen_conf):
-        from tencentcloud.hunyuan.v20230901 import models
         from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
             TencentCloudSDKException,
         )
+        from tencentcloud.hunyuan.v20230901 import models
 
         _gen_conf = {}
         _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
@@ -1296,10 +1203,10 @@ class HunyuanChat(Base):
             return ans + "\n**ERROR**: " + str(e), 0
 
     def chat_streamly(self, system, history, gen_conf):
-        from tencentcloud.hunyuan.v20230901 import models
         from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
             TencentCloudSDKException,
         )
+        from tencentcloud.hunyuan.v20230901 import models
 
         _gen_conf = {}
         _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
@@ -1327,7 +1234,7 @@ class HunyuanChat(Base):
                 resp = json.loads(resp["data"])
                 if not resp["Choices"] or not resp["Choices"][0]["Delta"]["Content"]:
                     continue
-                ans += resp["Choices"][0]["Delta"]["Content"]
+                ans = resp["Choices"][0]["Delta"]["Content"]
                 total_tokens += 1
 
                 yield ans
@@ -1339,9 +1246,7 @@ class HunyuanChat(Base):
 
 
 class SparkChat(Base):
-    def __init__(
-            self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
-    ):
+    def __init__(self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"):
         if not base_url:
             base_url = "https://spark-api-open.xf-yun.com/v1"
         model2version = {
@@ -1374,22 +1279,14 @@ class BaiduYiyanChat(Base):
     def chat(self, system, history, gen_conf):
         if system:
             self.system = system
-        gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
-                                                                0)) / 2
-        ) + 1
+        gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
         if "max_tokens" in gen_conf:
             del gen_conf["max_tokens"]
         ans = ""
 
         try:
-            response = self.client.do(
-                model=self.model_name,
-                messages=history,
-                system=self.system,
-                **gen_conf
-            ).body
-            ans = response['result']
+            response = self.client.do(model=self.model_name, messages=history, system=self.system, **gen_conf).body
+            ans = response["result"]
             return ans, self.total_token_count(response)
 
         except Exception as e:
@@ -1398,26 +1295,17 @@ class BaiduYiyanChat(Base):
     def chat_streamly(self, system, history, gen_conf):
         if system:
             self.system = system
-        gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
-                                                                0)) / 2
-        ) + 1
+        gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
         if "max_tokens" in gen_conf:
"max_tokens" in gen_conf: del gen_conf["max_tokens"] ans = "" total_tokens = 0 try: - response = self.client.do( - model=self.model_name, - messages=history, - system=self.system, - stream=True, - **gen_conf - ) + response = self.client.do(model=self.model_name, messages=history, system=self.system, stream=True, **gen_conf) for resp in response: resp = resp.body - ans += resp['result'] + ans = resp["result"] total_tokens = self.total_token_count(resp) yield ans @@ -1458,11 +1346,7 @@ class AnthropicChat(Base): ).to_dict() ans = response["content"][0]["text"] if response["stop_reason"] == "max_tokens": - ans += ( - "...\nFor the content length reason, it stopped, continue?" - if is_english([ans]) - else "······\n由于长度的原因,回答被截断了,要继续吗?" - ) + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" return ( ans, response["usage"]["input_tokens"] + response["usage"]["output_tokens"], @@ -1483,6 +1367,7 @@ class AnthropicChat(Base): ans = "" total_tokens = 0 + reasoning_start = False try: response = self.client.messages.create( model=self.model_name, @@ -1492,15 +1377,17 @@ class AnthropicChat(Base): **gen_conf, ) for res in response: - if res.type == 'content_block_delta': + if res.type == "content_block_delta": if res.delta.type == "thinking_delta" and res.delta.thinking: - if ans.find("") < 0: - ans += "" - ans = ans.replace("", "") + ans = "" + if not reasoning_start: + reasoning_start = True + ans = "" ans += res.delta.thinking + "" else: + reasoning_start = False text = res.delta.text - ans += text + ans = text total_tokens += num_tokens_from_string(text) yield ans except Exception as e: @@ -1511,13 +1398,12 @@ class AnthropicChat(Base): class GoogleChat(Base): def __init__(self, key, model_name, base_url=None): - from google.oauth2 import service_account import base64 + from google.oauth2 import service_account + key = json.loads(key) - access_token = json.loads( - base64.b64decode(key.get("google_service_account_key", "")) - ) + access_token = json.loads(base64.b64decode(key.get("google_service_account_key", ""))) project_id = key.get("google_project_id", "") region = key.get("google_region", "") @@ -1530,28 +1416,20 @@ class GoogleChat(Base): from google.auth.transport.requests import Request if access_token: - credits = service_account.Credentials.from_service_account_info( - access_token, scopes=scopes - ) + credits = service_account.Credentials.from_service_account_info(access_token, scopes=scopes) request = Request() credits.refresh(request) token = credits.token - self.client = AnthropicVertex( - region=region, project_id=project_id, access_token=token - ) + self.client = AnthropicVertex(region=region, project_id=project_id, access_token=token) else: self.client = AnthropicVertex(region=region, project_id=project_id) else: - from google.cloud import aiplatform import vertexai.generative_models as glm + from google.cloud import aiplatform if access_token: - credits = service_account.Credentials.from_service_account_info( - access_token - ) - aiplatform.init( - credentials=credits, project=project_id, location=region - ) + credits = service_account.Credentials.from_service_account_info(access_token) + aiplatform.init(credentials=credits, project=project_id, location=region) else: aiplatform.init(project=project_id, location=region) self.client = glm.GenerativeModel(model_name=self.model_name) @@ -1573,15 +1451,10 @@ class GoogleChat(Base): ).json() ans = response["content"][0]["text"] if response["stop_reason"] == 
"max_tokens": - ans += ( - "...\nFor the content length reason, it stopped, continue?" - if is_english([ans]) - else "······\n由于长度的原因,回答被截断了,要继续吗?" - ) + ans += "...\nFor the content length reason, it stopped, continue?" if is_english([ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" return ( ans, - response["usage"]["input_tokens"] - + response["usage"]["output_tokens"], + response["usage"]["input_tokens"] + response["usage"]["output_tokens"], ) except Exception as e: return "\n**ERROR**: " + str(e), 0 @@ -1598,9 +1471,7 @@ class GoogleChat(Base): if "content" in item: item["parts"] = item.pop("content") try: - response = self.client.generate_content( - history, generation_config=gen_conf - ) + response = self.client.generate_content(history, generation_config=gen_conf) ans = response.text return ans, response.usage_metadata.total_token_count except Exception as e: @@ -1627,7 +1498,7 @@ class GoogleChat(Base): res = res.decode("utf-8") if "content_block_delta" in res and "data" in res: text = json.loads(res[6:])["delta"]["text"] - ans += text + ans = text total_tokens += num_tokens_from_string(text) except Exception as e: yield ans + "\n**ERROR**: " + str(e) @@ -1647,11 +1518,9 @@ class GoogleChat(Base): item["parts"] = item.pop("content") ans = "" try: - response = self.model.generate_content( - history, generation_config=gen_conf, stream=True - ) + response = self.model.generate_content(history, generation_config=gen_conf, stream=True) for resp in response: - ans += resp.text + ans = resp.text yield ans except Exception as e: