mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat: Google drive supports web-based credentials (#11173)
### What problem does this PR solve? Google drive supports web-based credentials. <img width="1204" height="612" alt="image" src="https://github.com/user-attachments/assets/70291c63-a2dd-4a80-ae20-807fe034cdbc" /> ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -13,16 +13,26 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from html import escape
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required, current_user
|
||||
from flask import make_response, request
|
||||
from flask_login import current_user, login_required
|
||||
from google_auth_oauthlib.flow import Flow
|
||||
|
||||
from api.db import InputType
|
||||
from api.db.services.connector_service import ConnectorService, SyncLogsService
|
||||
from api.utils.api_utils import get_json_result, validate_request, get_data_error_result
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, validate_request
|
||||
from common.constants import RetCode, TaskStatus
|
||||
from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource
|
||||
from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES
|
||||
from common.misc_utils import get_uuid
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
@manager.route("/set", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@ -42,8 +52,8 @@ def set_connector():
|
||||
"config": req["config"],
|
||||
"refresh_freq": int(req.get("refresh_freq", 30)),
|
||||
"prune_freq": int(req.get("prune_freq", 720)),
|
||||
"timeout_secs": int(req.get("timeout_secs", 60*29)),
|
||||
"status": TaskStatus.SCHEDULE
|
||||
"timeout_secs": int(req.get("timeout_secs", 60 * 29)),
|
||||
"status": TaskStatus.SCHEDULE,
|
||||
}
|
||||
conn["status"] = TaskStatus.SCHEDULE
|
||||
ConnectorService.save(**conn)
|
||||
@ -105,3 +115,181 @@ def rm_connector(connector_id):
|
||||
ConnectorService.resume(connector_id, TaskStatus.CANCEL)
|
||||
ConnectorService.delete_by_id(connector_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
GOOGLE_WEB_FLOW_STATE_PREFIX = "google_drive_web_flow_state"
|
||||
GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result"
|
||||
WEB_FLOW_TTL_SECS = 15 * 60
|
||||
|
||||
|
||||
def _web_state_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}"
|
||||
|
||||
|
||||
def _web_result_cache_key(flow_id: str) -> str:
|
||||
return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}"
|
||||
|
||||
|
||||
def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]:
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
try:
|
||||
return json.loads(payload)
|
||||
except json.JSONDecodeError as exc: # pragma: no cover - defensive
|
||||
raise ValueError("Invalid Google credentials JSON.") from exc
|
||||
|
||||
|
||||
def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
web_section = credentials.get("web")
|
||||
if not isinstance(web_section, dict):
|
||||
raise ValueError("Google OAuth JSON must include a 'web' client configuration to use browser-based authorization.")
|
||||
return {"web": web_section}
|
||||
|
||||
|
||||
def _render_web_oauth_popup(flow_id: str, success: bool, message: str):
|
||||
status = "success" if success else "error"
|
||||
auto_close = "window.close();" if success else ""
|
||||
escaped_message = escape(message)
|
||||
payload_json = json.dumps(
|
||||
{
|
||||
"type": "ragflow-google-drive-oauth",
|
||||
"status": status,
|
||||
"flowId": flow_id or "",
|
||||
"message": message,
|
||||
}
|
||||
)
|
||||
html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format(
|
||||
heading="Authorization complete" if success else "Authorization failed",
|
||||
message=escaped_message,
|
||||
payload_json=payload_json,
|
||||
auto_close=auto_close,
|
||||
)
|
||||
response = make_response(html, 200)
|
||||
response.headers["Content-Type"] = "text/html; charset=utf-8"
|
||||
return response
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("credentials")
|
||||
def start_google_drive_web_oauth():
|
||||
if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
|
||||
return get_json_result(
|
||||
code=RetCode.SERVER_ERROR,
|
||||
message="Google Drive OAuth redirect URI is not configured on the server.",
|
||||
)
|
||||
|
||||
req = request.json or {}
|
||||
raw_credentials = req.get("credentials", "")
|
||||
try:
|
||||
credentials = _load_credentials(raw_credentials)
|
||||
except ValueError as exc:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
|
||||
|
||||
if credentials.get("refresh_token"):
|
||||
return get_json_result(
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
message="Uploaded credentials already include a refresh token.",
|
||||
)
|
||||
|
||||
try:
|
||||
client_config = _get_web_client_config(credentials)
|
||||
except ValueError as exc:
|
||||
return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
|
||||
|
||||
flow_id = str(uuid.uuid4())
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
authorization_url, _ = flow.authorization_url(
|
||||
access_type="offline",
|
||||
include_granted_scopes="true",
|
||||
prompt="consent",
|
||||
state=flow_id,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to create Google OAuth flow: %s", exc)
|
||||
return get_json_result(
|
||||
code=RetCode.SERVER_ERROR,
|
||||
message="Failed to initialize Google OAuth flow. Please verify the uploaded client configuration.",
|
||||
)
|
||||
|
||||
cache_payload = {
|
||||
"user_id": current_user.id,
|
||||
"client_config": client_config,
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
"flow_id": flow_id,
|
||||
"authorization_url": authorization_url,
|
||||
"expires_in": WEB_FLOW_TTL_SECS,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/callback", methods=["GET"]) # noqa: F821
|
||||
def google_drive_web_oauth_callback():
|
||||
state_id = request.args.get("state")
|
||||
error = request.args.get("error")
|
||||
error_description = request.args.get("error_description") or error
|
||||
|
||||
if not state_id:
|
||||
return _render_web_oauth_popup("", False, "Missing OAuth state parameter.")
|
||||
|
||||
state_cache = REDIS_CONN.get(_web_state_cache_key(state_id))
|
||||
if not state_cache:
|
||||
return _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")
|
||||
|
||||
state_obj = json.loads(state_cache)
|
||||
client_config = state_obj.get("client_config")
|
||||
if not client_config:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")
|
||||
|
||||
if error:
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, error_description or "Authorization was cancelled.")
|
||||
|
||||
code = request.args.get("code")
|
||||
if not code:
|
||||
return _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")
|
||||
|
||||
try:
|
||||
flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
|
||||
flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
|
||||
flow.fetch_token(code=code)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logging.exception("Failed to exchange Google OAuth code: %s", exc)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
return _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")
|
||||
|
||||
creds_json = flow.credentials.to_json()
|
||||
result_payload = {
|
||||
"user_id": state_obj.get("user_id"),
|
||||
"credentials": creds_json,
|
||||
}
|
||||
REDIS_CONN.set_obj(_web_result_cache_key(state_id), result_payload, WEB_FLOW_TTL_SECS)
|
||||
REDIS_CONN.delete(_web_state_cache_key(state_id))
|
||||
|
||||
return _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")
|
||||
|
||||
|
||||
@manager.route("/google-drive/oauth/web/result", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("flow_id")
|
||||
def poll_google_drive_web_result():
|
||||
req = request.json or {}
|
||||
flow_id = req.get("flow_id")
|
||||
cache_raw = REDIS_CONN.get(_web_result_cache_key(flow_id))
|
||||
if not cache_raw:
|
||||
return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")
|
||||
|
||||
result = json.loads(cache_raw)
|
||||
if result.get("user_id") != current_user.id:
|
||||
return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")
|
||||
|
||||
REDIS_CONN.delete(_web_result_cache_key(flow_id))
|
||||
return get_json_result(data={"credentials": result.get("credentials")})
|
||||
|
||||
Reference in New Issue
Block a user