diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py index ba78d6d6f..23965e617 100644 --- a/api/apps/connector_app.py +++ b/api/apps/connector_app.py @@ -13,16 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import json +import logging import time +import uuid +from html import escape +from typing import Any -from flask import request -from flask_login import login_required, current_user +from flask import make_response, request +from flask_login import current_user, login_required +from google_auth_oauthlib.flow import Flow from api.db import InputType from api.db.services.connector_service import ConnectorService, SyncLogsService -from api.utils.api_utils import get_json_result, validate_request, get_data_error_result -from common.misc_utils import get_uuid +from api.utils.api_utils import get_data_error_result, get_json_result, validate_request from common.constants import RetCode, TaskStatus +from common.data_source.config import GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, DocumentSource +from common.data_source.google_util.constant import GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE, GOOGLE_SCOPES +from common.misc_utils import get_uuid +from rag.utils.redis_conn import REDIS_CONN + @manager.route("/set", methods=["POST"]) # noqa: F821 @login_required @@ -42,8 +52,8 @@ def set_connector(): "config": req["config"], "refresh_freq": int(req.get("refresh_freq", 30)), "prune_freq": int(req.get("prune_freq", 720)), - "timeout_secs": int(req.get("timeout_secs", 60*29)), - "status": TaskStatus.SCHEDULE + "timeout_secs": int(req.get("timeout_secs", 60 * 29)), + "status": TaskStatus.SCHEDULE, } conn["status"] = TaskStatus.SCHEDULE ConnectorService.save(**conn) @@ -105,3 +115,181 @@ def rm_connector(connector_id): ConnectorService.resume(connector_id, TaskStatus.CANCEL) ConnectorService.delete_by_id(connector_id) return get_json_result(data=True) + + +GOOGLE_WEB_FLOW_STATE_PREFIX = 
"google_drive_web_flow_state" +GOOGLE_WEB_FLOW_RESULT_PREFIX = "google_drive_web_flow_result" +WEB_FLOW_TTL_SECS = 15 * 60 + + +def _web_state_cache_key(flow_id: str) -> str: + return f"{GOOGLE_WEB_FLOW_STATE_PREFIX}:{flow_id}" + + +def _web_result_cache_key(flow_id: str) -> str: + return f"{GOOGLE_WEB_FLOW_RESULT_PREFIX}:{flow_id}" + + +def _load_credentials(payload: str | dict[str, Any]) -> dict[str, Any]: + if isinstance(payload, dict): + return payload + try: + return json.loads(payload) + except json.JSONDecodeError as exc: # pragma: no cover - defensive + raise ValueError("Invalid Google credentials JSON.") from exc + + +def _get_web_client_config(credentials: dict[str, Any]) -> dict[str, Any]: + web_section = credentials.get("web") + if not isinstance(web_section, dict): + raise ValueError("Google OAuth JSON must include a 'web' client configuration to use browser-based authorization.") + return {"web": web_section} + + +def _render_web_oauth_popup(flow_id: str, success: bool, message: str): + status = "success" if success else "error" + auto_close = "window.close();" if success else "" + escaped_message = escape(message) + payload_json = json.dumps( + { + "type": "ragflow-google-drive-oauth", + "status": status, + "flowId": flow_id or "", + "message": message, + } + ) + html = GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE.format( + heading="Authorization complete" if success else "Authorization failed", + message=escaped_message, + payload_json=payload_json, + auto_close=auto_close, + ) + response = make_response(html, 200) + response.headers["Content-Type"] = "text/html; charset=utf-8" + return response + + +@manager.route("/google-drive/oauth/web/start", methods=["POST"]) # noqa: F821 +@login_required +@validate_request("credentials") +def start_google_drive_web_oauth(): + if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI: + return get_json_result( + code=RetCode.SERVER_ERROR, + message="Google Drive OAuth redirect URI is not configured on the server.", + ) + + req = 
@manager.route("/google-drive/oauth/web/start", methods=["POST"])  # noqa: F821
@login_required
@validate_request("credentials")
def start_google_drive_web_oauth():
    """Begin a browser-based Google Drive OAuth flow for the current user.

    Reads ``credentials`` from the JSON request body, builds a Google
    authorization URL for the ``web`` client it contains, and caches the
    client configuration in Redis under a fresh flow id so the callback route
    can finish the exchange.

    Returns:
        A JSON result with ``flow_id``, ``authorization_url`` and
        ``expires_in`` on success; an ARGUMENT_ERROR result when the uploaded
        credentials are unusable; a SERVER_ERROR result when the redirect URI
        is not configured or the flow cannot be created.
    """
    if not GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI:
        return get_json_result(
            code=RetCode.SERVER_ERROR,
            message="Google Drive OAuth redirect URI is not configured on the server.",
        )

    req = request.json or {}
    raw_credentials = req.get("credentials", "")
    try:
        credentials = _load_credentials(raw_credentials)
    except ValueError as exc:
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))
    except TypeError:
        # json.loads raises TypeError for non-string input (e.g. a number in
        # the request body); report it as a bad argument instead of a 500.
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google credentials JSON.")

    if not isinstance(credentials, dict):
        # Valid JSON that is not an object ("[]", "null", ...) would otherwise
        # fail on the .get() calls below.
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message="Invalid Google credentials JSON.")

    if credentials.get("refresh_token"):
        return get_json_result(
            code=RetCode.ARGUMENT_ERROR,
            message="Uploaded credentials already include a refresh token.",
        )

    try:
        client_config = _get_web_client_config(credentials)
    except ValueError as exc:
        return get_json_result(code=RetCode.ARGUMENT_ERROR, message=str(exc))

    # The flow id doubles as the OAuth ``state`` value returned to the callback.
    flow_id = str(uuid.uuid4())
    try:
        flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
        flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
        authorization_url, _ = flow.authorization_url(
            access_type="offline",
            include_granted_scopes="true",
            prompt="consent",
            state=flow_id,
        )
    except Exception as exc:  # pragma: no cover - defensive
        logging.exception("Failed to create Google OAuth flow: %s", exc)
        return get_json_result(
            code=RetCode.SERVER_ERROR,
            message="Failed to initialize Google OAuth flow. Please verify the uploaded client configuration.",
        )

    # Remember who started the flow so only that user can collect the result.
    cache_payload = {
        "user_id": current_user.id,
        "client_config": client_config,
        "created_at": int(time.time()),
    }
    REDIS_CONN.set_obj(_web_state_cache_key(flow_id), cache_payload, WEB_FLOW_TTL_SECS)

    return get_json_result(
        data={
            "flow_id": flow_id,
            "authorization_url": authorization_url,
            "expires_in": WEB_FLOW_TTL_SECS,
        }
    )
@manager.route("/google-drive/oauth/web/callback", methods=["GET"])  # noqa: F821
def google_drive_web_oauth_callback():
    """Handle the redirect back from Google at the end of the consent screen.

    Looks up the cached flow state named by the ``state`` query parameter,
    exchanges the ``code`` parameter for tokens, stores the serialized
    credentials in Redis under the result key, and always answers with the
    popup HTML page describing the outcome.
    """
    params = request.args
    state_id = params.get("state")
    error = params.get("error")
    error_description = params.get("error_description") or error

    if not state_id:
        return _render_web_oauth_popup("", False, "Missing OAuth state parameter.")

    state_key = _web_state_cache_key(state_id)
    state_cache = REDIS_CONN.get(state_key)
    if not state_cache:
        return _render_web_oauth_popup(state_id, False, "Authorization session expired. Please restart from the main window.")

    state_obj = json.loads(state_cache)
    client_config = state_obj.get("client_config")
    if not client_config:
        REDIS_CONN.delete(state_key)
        return _render_web_oauth_popup(state_id, False, "Authorization session was invalid. Please retry.")

    if error:
        # Google reported a failure (or the user declined); the flow is over.
        REDIS_CONN.delete(state_key)
        failure_text = error_description or "Authorization was cancelled."
        return _render_web_oauth_popup(state_id, False, failure_text)

    code = params.get("code")
    if not code:
        return _render_web_oauth_popup(state_id, False, "Missing authorization code from Google.")

    try:
        flow = Flow.from_client_config(client_config, scopes=GOOGLE_SCOPES[DocumentSource.GOOGLE_DRIVE])
        flow.redirect_uri = GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI
        flow.fetch_token(code=code)
    except Exception as exc:  # pragma: no cover - defensive
        logging.exception("Failed to exchange Google OAuth code: %s", exc)
        REDIS_CONN.delete(state_key)
        return _render_web_oauth_popup(state_id, False, "Failed to exchange tokens with Google. Please retry.")

    # Publish the result for the polling endpoint, then retire the pending state.
    REDIS_CONN.set_obj(
        _web_result_cache_key(state_id),
        {
            "user_id": state_obj.get("user_id"),
            "credentials": flow.credentials.to_json(),
        },
        WEB_FLOW_TTL_SECS,
    )
    REDIS_CONN.delete(state_key)

    return _render_web_oauth_popup(state_id, True, "Authorization completed successfully.")


@manager.route("/google-drive/oauth/web/result", methods=["POST"])  # noqa: F821
@login_required
@validate_request("flow_id")
def poll_google_drive_web_result():
    """Return the credentials produced by a finished flow to its owner.

    Responds with RUNNING while no result is cached for ``flow_id``, with
    PERMISSION_ERROR when the cached result belongs to a different user, and
    otherwise deletes the cached result and returns its credentials.
    """
    flow_id = (request.json or {}).get("flow_id")
    result_key = _web_result_cache_key(flow_id)

    cache_raw = REDIS_CONN.get(result_key)
    if not cache_raw:
        return get_json_result(code=RetCode.RUNNING, message="Authorization is still pending.")

    result = json.loads(cache_raw)
    if result.get("user_id") != current_user.id:
        return get_json_result(code=RetCode.PERMISSION_ERROR, message="You are not allowed to access this authorization result.")

    # One-shot hand-off: the result is removed once its owner has read it.
    REDIS_CONN.delete(result_key)
    return get_json_result(data={"credentials": result.get("credentials")})
5eea4f6e4..196d9ed3e 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -190,6 +190,7 @@ OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "" OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get( "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", "" ) +GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = os.environ.get("GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "http://localhost:9380/v1/connector/google-drive/oauth/web/callback") CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token" RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() diff --git a/common/data_source/google_util/constant.py b/common/data_source/google_util/constant.py index c0b7f0711..8ab75fa14 100644 --- a/common/data_source/google_util/constant.py +++ b/common/data_source/google_util/constant.py @@ -47,3 +47,57 @@ USER_FIELDS = "nextPageToken, users(primaryEmail)" # Error message substrings MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested" SCOPE_INSTRUCTIONS = "" + + +GOOGLE_DRIVE_WEB_OAUTH_POPUP_TEMPLATE = """ + +
+ +{message}
+You can close this window.
+