0%

python httpx-oauth 源码分析

一直想学习下 OAuth2(Open Authorization) 协议,最近看 fastapi-users 的源代码,发现它们使用 HTTPX OAuth 来提供 OAuth2 授权支持,于是就顺便研究了下 httpx-oauth 的源码。

OAuth2 协议简介

OAuth 2.0 是一个授权框架 (Authorization Framework),它的核心目标是让用户(资源所有者)授权第三方应用程序(客户端)访问其在服务提供商(授权服务器/资源服务器)上存储的受保护资源,而无需将用户的凭证(如密码)提供给第三方应用程序。也就是说,它解决的一个本质问题是:让第三方应用,在不拿到你密码的情况下,安全地访问你的资源

核心概念

OAuth 2.0 协议定义了四个核心角色,它们协同工作完成授权流程:

角色名称 职责描述
资源所有者用户本人(Resource Owner) 用户本人,拥有受保护资源和授权访问的权利
客户端(Client) 第三方应用程序,该应用程序代表资源所有者,请求访问受保护资源(例如:一个第三方 web 应用,请求使用 google 账户登录)
授权服务器(Authorization Server) 颁发令牌的服务器,负责验证资源所有者的身份,并在授权通过后,向客户端颁发访问令牌 (Access Token)
资源服务器(Resource Server) 存储受保护数据的服务器,负责托管受保护资源,并接收和验证客户端携带的 Access Token,决定是否给予访问

核心流程

OAuth 2.0 定义了多种授权模式,以适应不同的客户端类型和安全需求。接下来以最常用的授权码模式为例,介绍 OAuth 2.0 的核心流程,为了讲解方便,这个流程以 Github 授权服务 为例。为了让你的 Web 应用能够使用 GitHub 登录,你需要先在 GitHub 应用注册页面 上为你的 Web 应用创建一个 OAuth App:

  • 步骤 1: 客户端应用请求授权

    • 用户在客户端应用(即第三方应用)上开始请求 Github 授权服务器 授权,例如点击 使用 GitHub 登录 按钮
    • 此时客户端应用将用户重定向到 Github 授权服务器的 授权 URL
    • 该重定向请求中会包含以下重要参数:
      • client_id:客户端应用在授权服务器注册时获得的唯一标识符
      • redirect_uri:授权成功或失败后,授权服务器将用户重定向回来的 URL(必须与注册时填写的 URL 匹配)
      • scope:客户端应用希望获得的权限范围(例如,读取用户基本信息 read:user
      • state:一个随机字符串,用于防止跨站请求伪造 (CSRF) 攻击
  • 步骤 2: 资源所有者授予权限

    • 授权服务器(GitHub)验证用户的身份(要求用户输入 GitHub 账号和密码)
    • GitHub 向用户展示一个授权页面,说明客户端应用请求访问哪些权限(即 scope)
    • 资源所有者(用户)点击 授权 按钮,同意授予权限
  • 步骤 3:授权服务器返回授权码

    • 授权服务器(GitHub)将用户重定向回客户端应用注册时填充的 redirect_uri
    • 重定向 URL 中包含一个一次性、短暂有效的授权码 (Authorization Code) 和原先的 state 参数(客户端应用需要检查返回的 state 值是否与请求时所传递的值一致,以防止 CSRF 攻击)
  • 步骤 4: 客户端应用获取访问令牌(Access Token)

    • 客户端应用接收到授权码后,立即使用该码授权向 Github 授权服务器的 令牌交换 endpoint 发送 POST 请求,这个请求包含如下重要参数:
      • client_id:客户端应用 ID
      • client_secret:客户端应用在 GitHub 注册时获得的密钥(这是确保请求真实性的关键,必须保存在后端,绝不能泄露给前端)
      • code:上一步所获得的授权码
      • redirect_uri:必须与发送授权请求时所使用的 redirect_uri 参数保持一致。授权服务器用它来验证这个请求的合法性,这确保了只有最开始请求授权码的那个 URI 才有资格在第二步进行令牌交换,这极大地限制了攻击者窃取授权码并成功换取访问令牌的可能性
    • 这个请求发生在客户端应用的服务器端,而不是前端。这是因为 client_secret 必须保密,绝不能泄露给前端用户或第三方应用
  • 步骤 5: 授权服务器颁发访问令牌

    • GitHub 授权服务器验证 client_id、client_secret、code 等参数的有效性
    • 验证成功后,GitHub 向客户端应用返回一个 JSON 响应,其中包含:
      • 访问令牌 (Access Token):一个用于访问资源服务器(GitHub API)的凭证
      • token_type:令牌类型(通常是 Bearer)
      • expires_in (可选):令牌的有效期
    • 标准 OAuth 2.0 规范还允许同时返回一个 Refresh Token。如果授权服务器颁发了 Refresh Token,客户端可以利用它在 Access Token 过期后获取新的 Access Token,而无需用户重新授权
  • 步骤 6:客户端应用访问受保护资源

    • 客户端应用使用获得的访问令牌,向资源服务器(GitHub API)发送请求,例如获取用户用户名、头像等基本信息
    • 请求时,访问令牌通常放在 HTTP 请求头的 Authorization 字段中,格式为:Authorization: Bearer [Access Token]
  • 步骤 7:资源服务器返回资源

    • 资源服务器(GitHub API)验证访问令牌的有效性
    • 如果令牌有效且拥有足够的权限(在 scope 定义范围内),API 将返回用户请求的资源,例如用户的 Github 基本用户信息
    • 客户端应用通常会使用这些数据在自己的系统中创建或登录一个用户账户,完成授权登录流程

OAuth2 中的 CSRF 攻击

上述流程中,说到 OAuth2 中的 state 参数是用于预防 CSRF 攻击的。那 OAuth2 中的 CSRF 攻击是怎样的呢:

  • 攻击者在自己的浏览器上开始 OAuth 流程(比如点击 使用 Github 登录
  • 当 Github 返回重定向请求(包含返回给攻击者的授权 Code)时,攻击者拦截这个请求,不让自己的浏览器发给应用服务器
  • 攻击者拿到了一个有效的 URL,类似于:https://app.com/callback?code=attacker_code
  • 攻击者通过邮件、恶意链接或隐藏的 <img> 标签,诱骗受害者点击这个链接(或者让受害者的浏览器自动加载这个 URL)
  • 受害者的浏览器向应用服务器发起了请求:GET /callback?code=attacker_code,但当前身份 cookie 等信息是受害者的
  • 应用服务器认为这是一个合法的请求(因为它确实有有效的 Code),应用服务器用 attacker_code 去 Github 换取 Token,Github 返回应用服务器 攻击者(Attacker)的 Token
  • 应用服务器为受害者建立 Session,但是,这个 Session 对应的是攻击者的 Github 账号

而通过 state 参数,应用服务器通过检查 回调 URL 中的 state 参数与 发起请求时的 state 参数 是否一致,可以有效地防止这种攻击。

httpx-oauth 源码分析

在学习了 OAuth2 协议之后,我们再来分析下 httpx-oauth 的源码。httpx-oauth 是基于 httpx 的异步 OAuth2 客户端实现。

httpx-oauth 的设计思路非常清晰:一个基类 + 多个预定义子类

  • BaseOAuth2(基类):可以认为是 httpx-oauth 库的核心数据结构,定义了实现 OAuth2 Client 的通用逻辑,包括 生成授权 URL获取 Access Token 等流程
  • 具体 Provider(预定义子类):GitHubOAuth2GoogleOAuth2 等具体的 OAuth2 Client 实现类,提供了针对不同 OAuth2 服务商(例如 Github、Google)的特定 Client 实现。库的使用者可以直接使用这些预定义的 OAuth client 来进行 OAuth2 认证授权

下图则展示了 http-oauth 的源码文件结构:

1
2
3
4
5
6
7
8
9
# tree --dirsfirst -L 1 .
.
├── clients
├── integrations
├── branding.py
├── exceptions.py
├── __init__.py
├── oauth2.py
└── py.typed
  • clients 目录:存放具体的 OAuth2 Client 实现类
  • integrations 目录:提供了与 FastAPI 集成时的一些 helper 功能
  • branding.py:定义了一个 BrandingProtocol Protocol 类,用于表示所有具有的 OAuth2 Client 都需要遵守的协议
  • exceptions.py:定义 http-oauth 库抛出的异常
  • oauth2.py:定义了 BaseOAuth2 这个核心基类
  • __init__.pyhttp-oauth 包的 __init__ 文件
  • py.typed:告诉类型检查器,这个 Python 包是 带完整类型信息的

BaseOAuth2

BaseOAuth2 类是 httpx-oauth 库的核心,它定义了实现 OAuth2 Client 的通用逻辑。它的定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class BaseOAuth2(Generic[T]):
name: str
client_id: str
client_secret: str
authorize_endpoint: str
access_token_endpoint: str
refresh_token_endpoint: Optional[str]
revoke_token_endpoint: Optional[str]
base_scopes: Optional[list[str]]
token_endpoint_auth_method: OAuth2ClientAuthMethod
revocation_endpoint_auth_method: Optional[OAuth2ClientAuthMethod]
request_headers: dict[str, str]

def __init__(
self,
client_id: str,
client_secret: str,
authorize_endpoint: str,
access_token_endpoint: str,
refresh_token_endpoint: Optional[str] = None,
revoke_token_endpoint: Optional[str] = None,
*,
name: str = "oauth2",
base_scopes: Optional[list[str]] = None,
token_endpoint_auth_method: OAuth2ClientAuthMethod = "client_secret_post",
revocation_endpoint_auth_method: Optional[OAuth2ClientAuthMethod] = None,
):
# 参数的校验
_check_valid_auth_method(token_endpoint_auth_method)
if revocation_endpoint_auth_method is not None:
_check_valid_auth_method(revocation_endpoint_auth_method)
if (
revoke_token_endpoint is not None
and revocation_endpoint_auth_method is None
):
raise MissingRevokeTokenAuthMethodError()

# 设置 client 的各个属性
self.client_id = client_id
self.client_secret = client_secret
self.authorize_endpoint = authorize_endpoint
self.access_token_endpoint = access_token_endpoint
self.refresh_token_endpoint = refresh_token_endpoint
self.revoke_token_endpoint = revoke_token_endpoint
self.name = name
self.base_scopes = base_scopes
self.token_endpoint_auth_method = token_endpoint_auth_method
self.revocation_endpoint_auth_method = revocation_endpoint_auth_method

self.request_headers = {
"Accept": "application/json",
}
  • BaseOAuth2 是一个泛型类,用 T 类型变量来表示构建授权 URL 时的额外参数的类型(下文可以看到)
  • 在构造函数中,完成了对部分参数的校验
  • 之后设置了 BaseOAuth2 类的各个属性,这些属性都是 OAuth2 协议中的关键参数
    • client_id、client_secret 是注册应用时由 OAuth2 服务商提供的标识符和密钥
    • authorize_endpoint、access_token_endpoint 等是 OAuth2 服务商所对应的 授权 endpointAccess Token endpoint

虽然 token_endpoint_auth_method 等参数使用了类型注解 OAuth2ClientAuthMethod,但是类型注解只在代码静态分析时有用,运行时并不会做任何校验。_check_valid_auth_method 函数用于传入的认证方法参数是否合法:

1
2
3
4
5
OAuth2ClientAuthMethod = Literal["client_secret_basic", "client_secret_post"]

def _check_valid_auth_method(auth_method: str) -> None:
if auth_method not in get_args(OAuth2ClientAuthMethod):
raise NotSupportedAuthMethodError(auth_method)
  • auth_method 必须是 client_secret_basicclient_secret_post 字符串字面量之一
  • 这里使用了 python typing 库的 get_args 函数,用于获取 Literal 的所有可能的值。
    • get_args() 是 Python 标准库 typing 模块中的一个运行时函数,用于提取泛型类型(Generic Types)或特殊类型(如 Literal, Union 等)的内部参数
    • 与之对应的,还有一个 get_origin 函数用于获取 外层类型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
>>> AuthMethod = Literal["basic", "post", "none"]
>>> print(get_args(AuthMethod))
('basic', 'post', 'none')

>>> U = int | str
>>> print(get_args(U))
(<class 'int'>, <class 'str'>)

>>> print(get_args(list[str]))
(<class 'str'>,)

>>> get_origin(list[str, int])
<class 'list'>
>>> get_args(list[str, int])
(<class 'str'>, <class 'int'>)

BaseOAuth2 中的 get_authorization_url 方法用于构建授权 URL,用户需要被重定向到这个函数所返回的 URL 上开始 OAuth2 授权流程,即引导用户跳转到第三方平台(如 GitHub、Google)进行登录授权。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
async def get_authorization_url(
self,
redirect_uri: str,
state: Optional[str] = None,
scope: Optional[list[str]] = None,
code_challenge: Optional[str] = None,
code_challenge_method: Optional[Literal["plain", "S256"]] = None,
extras_params: Optional[T] = None,
) -> str:
params = {
"response_type": "code",
"client_id": self.client_id,
"redirect_uri": redirect_uri,
}

if state is not None:
params["state"] = state

# Provide compatibility with current scope from the endpoint
_scope = scope or self.base_scopes
if _scope is not None:
params["scope"] = " ".join(_scope)

if code_challenge is not None:
params["code_challenge"] = code_challenge

if code_challenge_method is not None:
params["code_challenge_method"] = code_challenge_method

if extras_params is not None:
params = {**params, **extras_params} # type: ignore

return f"{self.authorize_endpoint}?{urlencode(params)}"
  • 响应类型为 code,表示获取 授权码
  • redirect_uri 参数表示:用户授权完成后,第三方平台应该把用户重定向到哪个页面
  • code_challengecode_challenge_method 参数用于 PKCE(Proof Key for Code Exchange)扩展,目的是防止授权码被拦截夺取
  • extras_params 参数用来提供额外的参数,这些参数针对不同的第三方平台可能会有所不同
  • 最终这些参数都会作为授权 URL 中的查询参数,并根据 authorize_endpoint 拼接成最终的授权 URL

BaseOAuth2 中的 get_access_token 方法则根据授权码(authorization code)获取 Access Token:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
async def get_access_token(
self, code: str, redirect_uri: str, code_verifier: Optional[str] = None
) -> OAuth2Token:
async with self.get_httpx_client() as client:
data = {
"grant_type": "authorization_code",
"code": code,
"redirect_uri": redirect_uri,
}

if code_verifier:
data["code_verifier"] = code_verifier

request, auth = self.build_request(
client,
"POST",
self.access_token_endpoint,
auth_method=self.token_endpoint_auth_method,
data=data,
)
response = await self.send_request(
client, request, auth, exc_class=GetAccessTokenError
)
data = self.get_json(response, exc_class=GetAccessTokenError)
return OAuth2Token(data)

def build_request(
self,
client: httpx.AsyncClient,
method: str,
url: str,
*,
auth_method: Union[OAuth2ClientAuthMethod, None] = None,
data: Union[Mapping[str, Any], None] = None,
) -> tuple[httpx.Request, Union[httpx.Auth, None]]:
if data is not None:
data = {
**data,
**(
{
"client_id": self.client_id,
"client_secret": self.client_secret,
}
if auth_method == "client_secret_post"
else {}
),
}

request = client.build_request(
method,
url,
data=data,
headers=self.request_headers,
)

auth = None
if auth_method == "client_secret_basic":
auth = httpx.BasicAuth(self.client_id, self.client_secret)

return request, auth

async def send_request(
self,
client: httpx.AsyncClient,
request: httpx.Request,
auth: Union[httpx.Auth, None],
*,
exc_class: type[OAuth2RequestError],
) -> httpx.Response:
try:
response = await client.send(request, auth=auth)
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise exc_class(str(e), e.response) from e
# 表示 任何与 HTTP 请求/响应相关的错误,包括网络问题、超时、无效 URL、SSL 错误等。
except httpx.HTTPError as e:
raise exc_class(str(e)) from e

return response
  • 构造的请求中,包含 grant_typeauthorization_code,表示是使用授权码来获取 Access Token。同时传入 code 和 redirect_uri 参数
  • build_request 根据 auth_method 的不同,实现不同的认证流程:
    • client_secret_post:直接把 client_idclient_secret 塞进 POST 的 Body 里(编码为表单数据格式)
    • client_secret_basic:基于 HTTP Basic Auth 认证方法,将 client_idclient_secret 构造成 Authorization: Basic <base64> 形式的认证 Header
  • 通过传入 exc_class,它实现了异常的精准定位。如果这一步出错了,抛出的一定是 GetAccessTokenError
  • send_request 负责发送请求,并处理可能的异常流畅。它使用 raise_for_status 来检查响应状态码,如果表示错误(4xx 或 5xx),它会抛出一个异常。
    • 对于 httpx.HTTPStatusError,此时表示应用层的错误,此时是可以获取响应信息的
    • 对于 httpx.HTTPError,表示任何与 HTTP 请求/响应相关的错误,包括网络问题、超时等,此时可能无法获取到响应信息
  • 将获取到的 json 格式的响应转换为 OAuth2Token 对象,其实它就是一个字典,只不过包含了 过期时间 的处理:
1
2
3
4
5
6
7
8
9
10
11
12
13
class OAuth2Token(dict[str, Any]):
def __init__(self, token_dict: dict[str, Any]):
# 设置过期时间
if "expires_at" in token_dict:
token_dict["expires_at"] = int(token_dict["expires_at"])
elif "expires_in" in token_dict:
token_dict["expires_at"] = int(time.time()) + int(token_dict["expires_in"])
super().__init__(token_dict)

def is_expired(self) -> bool:
if "expires_at" not in self:
return False
return time.time() > self["expires_at"]

OAuth2 支持 Refresh Token 机制,应用可以在不需要用户重新授权的情况下获取新的 Access Token,此时是通过 Refresh Token 来获取新的 Access Token:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
async def refresh_token(self, refresh_token: str) -> OAuth2Token:
if self.refresh_token_endpoint is None:
raise RefreshTokenNotSupportedError()

async with self.get_httpx_client() as client:
request, auth = self.build_request(
client,
"POST",
self.refresh_token_endpoint,
auth_method=self.token_endpoint_auth_method,
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token,
},
)
response = await self.send_request(
client, request, auth, exc_class=RefreshTokenError
)
data = self.get_json(response, exc_class=RefreshTokenError)
return OAuth2Token(data)
  • 刷新令牌的请求必须包含 grant_type="refresh_token",表示是使用刷新令牌来获取新的 Access Token

BaseOAuth2 类还提供了 revoke_token() 来撤销 Access Token 或 Refresh Token:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
async def revoke_token(
self, token: str, token_type_hint: Optional[str] = None
) -> None:
if self.revoke_token_endpoint is None:
raise RevokeTokenNotSupportedError()

async with self.get_httpx_client() as client:
data = {"token": token}

if token_type_hint is not None:
data["token_type_hint"] = token_type_hint

request, auth = self.build_request(
client,
"POST",
self.revoke_token_endpoint,
auth_method=self.token_endpoint_auth_method,
data=data,
)
await self.send_request(client, request, auth, exc_class=RevokeTokenError)

return None

最后,BaseOAuth2 类还定义了两个接口,用于根据 Access Token 获取用户信息,但这两个接口的实现需要由子类来完成:

1
2
3
4
5
async def get_profile(self, token: str) -> dict[str, Any]:
raise NotImplementedError()

async def get_id_email(self, token: str) -> tuple[str, Optional[str]]:
raise NotImplementedError()

BaseOAuth2 类是一个泛型类,通过类型变量 T 来表示构建授权 URL 时额外的参数类型,一般该类型就是一个 dict,因此为了便于使用,库代码直接定义了一个具体的 OAuth2 类型,库的用户直接使用该类即可:

1
OAuth2 = BaseOAuth2[dict[str, Any]]

以上就分析了 httpx-oauth 库的核心类 BaseOAuth2 的实现细节。可以看到,它的主要功能就是通过 httpx 这个 HTTP 客户端库来封装出 OAuth 认证流程的核心步骤,例如 构造授权 URL获取 Token刷新 Token撤销 Token 等。

GoogleOAuth2

上文说过,httpx-oauth 库针对不同的 OAuth 2.0 服务提供商预定义了一些具体的 OAuth2 Client 实现,这些 client 实现都保存在 clients 子目录中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
.
├── discord.py
├── facebook.py
├── franceconnect.py
├── github.py
├── google.py
├── __init__.py
├── kakao.py
├── linkedin.py
├── microsoft.py
├── naver.py
├── okta.py
├── openid.py
├── reddit.py
└── shopify.py

我们以 GoogleOAuth2 的实现为例,分析如何基于 BaseOAuth2 基类定义一个具体的 OAuth2 Client 实现类:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class GoogleOAuth2AuthorizeParams(TypedDict, total=False):
access_type: Literal["online", "offline"]
include_granted_scopes: bool
login_hint: str
prompt: Literal["none", "consent", "select_account"]


class GoogleOAuth2(BaseOAuth2[GoogleOAuth2AuthorizeParams]):
"""OAuth2 client for Google."""

display_name = "Google"
logo_svg = LOGO_SVG

def __init__(
self,
client_id: str,
client_secret: str,
scopes: Optional[list[str]] = BASE_SCOPES,
name: str = "google",
):
super().__init__(
client_id,
client_secret,
AUTHORIZE_ENDPOINT,
ACCESS_TOKEN_ENDPOINT,
ACCESS_TOKEN_ENDPOINT,
REVOKE_TOKEN_ENDPOINT,
name=name,
base_scopes=scopes,
token_endpoint_auth_method="client_secret_post",
revocation_endpoint_auth_method="client_secret_post",
)

async def get_profile(self, token: str) -> dict[str, Any]:
async with self.get_httpx_client() as client:
response = await client.get(
PROFILE_ENDPOINT,
params={"personFields": "emailAddresses"},
headers={**self.request_headers, "Authorization": f"Bearer {token}"},
)

if response.status_code >= 400:
raise GetProfileError(response=response)

return cast(dict[str, Any], response.json())

async def get_id_email(self, token: str) -> tuple[str, Optional[str]]:
try:
profile = await self.get_profile(token)
except GetProfileError as e:
raise GetIdEmailError(response=e.response) from e

user_id = profile["resourceName"]
user_email = next(
email["value"]
for email in profile["emailAddresses"]
if email["metadata"]["primary"]
)

return user_id, user_email
  • GoogleOAuth2AuthorizeParams 是一个 TypedDict,用于定义使用 Google OAuth 服务时构造授权 URL 所需要的额外参数
  • Google OAuth 服务的各个 Endpoint 都是固定的,因此 GoogleOAuth2.__init__() 方法中不在需要传入这些参数
  • get_profile 方法使用获取到的 Token 来请求 Google 的 profile endpoint,用于获取用户信息
  • get_id_email 方法则进一步从获取到的用户 profile 信息中提取到 user_iduser_email

OAuth2AuthorizeCallback

httpx_oauth/integrations/fastapi.py 中还定义了一个 OAuth2AuthorizeCallback 类,用于在 FastAPI 框架中处理 OAuth 授权回调流程。

上文说过,在构造 授权 URL 时,需要传入一个 redirect_uri 参数,这个参数用于在授权结束后将用户重定向到这个 URL。也就是说,用户授权结束后会跳转回你的应用,你需要处理这个请求:从这个请求中提取 Code,然后使用这个 Code 来获取 Access Token。而 OAuth2AuthorizeCallback 类则用于协助你的 FastAPI 应用处理这个回调请求。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class OAuth2AuthorizeCallback:
client: BaseOAuth2
route_name: Optional[str]
redirect_url: Optional[str]

def __init__(
self,
client: BaseOAuth2,
route_name: Optional[str] = None,
redirect_url: Optional[str] = None,
):
assert (route_name is not None and redirect_url is None) or (
route_name is None and redirect_url is not None
), "You should either set route_name or redirect_url"
self.client = client
self.route_name = route_name
self.redirect_url = redirect_url

async def __call__(
self,
request: Request,
code: Optional[str] = None,
code_verifier: Optional[str] = None,
state: Optional[str] = None,
error: Optional[str] = None,
) -> tuple[OAuth2Token, Optional[str]]:
if code is None or error is not None:
raise OAuth2AuthorizeCallbackError(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error if error is not None else None,
)

if self.route_name:
redirect_url = str(request.url_for(self.route_name))
elif self.redirect_url:
redirect_url = self.redirect_url

try:
access_token = await self.client.get_access_token(
code, redirect_url, code_verifier
)
except GetAccessTokenError as e:
raise OAuth2AuthorizeCallbackError(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e.message,
response=e.response,
) from e

return access_token, state
  • 该类实现了 __call__ 方法,在 FastAPI 中,这允许你像使用函数一样使用这个类的实例作为依赖项:Depends(oauth2_authorize_callback)。
  • 当这个类被 Depends 调用时,FastAPI 会自动审视 __call__ 的签名。它发现有 code、state、error 这些参数,就会自动从当前请求 URL 的查询参数中提取它们
  • 为了获取 redirect_url,需要提供一个 route_name 或者直接传入 redirect_url 参数。如果传递的是 route_name,会通过 request.url_for(route_name) 自动获取完整的 URL
  • 之后通过 BaseOAuth2.get_access_token 方法来使用 code 获取 Access Token,并返回对应的 tokenstate

借助这个类,你可以在 FastAPI 应用中非常方便地处理 OAuth 授权后的回调流程:

1
2
3
4
5
6
7
8
9
10
11
12
from fastapi import FastAPI, Depends
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import OAuth2

client = OAuth2("CLIENT_ID", "CLIENT_SECRET", "AUTHORIZE_ENDPOINT", "ACCESS_TOKEN_ENDPOINT")
oauth2_authorize_callback = OAuth2AuthorizeCallback(client, "oauth-callback")
app = FastAPI()

@app.get("/oauth-callback", name="oauth-callback")
async def oauth_callback(access_token_state=Depends(oauth2_authorize_callback)):
token, state = access_token_state
# Do something useful
  • 由于 oauth2_authorize_callback 是个可调用对象(OAuth2AuthorizeCallback 实现了 __call__ 方法),所以可以直接作为 Depends 依赖项
  • FastAPI 会自动从当前请求的 URL 查询参数(Query Parameters)中获取 codestate 参数,并开始执行 OAuth2AuthorizeCallback.__call__ 中的逻辑:根据授权码申请 Access Token,并返回 tokenstate
  • 因此 oauth_callback 的函数体开始执行时,就可以直接使用 tokenstate 来进行后续操作了

小结

以上就基本分析完了 httpx-oauth 库的核心代码逻辑,主要是通过学习该库的代码,来对 OAuth2 授权认证流程有个直观的认识。

Reference