使用 Axum 与 Tower 中间件构建支持 OAuth 2.0 鉴权的通用二阶段提交协调器


多服务原子操作是个老生常谈的难题。在一个典型的微服务场景中,一个用户下单操作可能需要同时调用订单服务创建订单和钱包服务扣减余额。如果其中一个服务调用成功而另一个失败,系统就会进入数据不一致的状态。在真实项目中,这种不一致性是灾难性的,尤其是在金融或电商领域。

最初的构想是通过重试和补偿逻辑来处理,但这很快让业务代码变得臃肿不堪,充满了各种状态检查和修复逻辑。我们需要一个更通用的解决方案。自然而然,我想到了分布式事务协议。尽管 Saga、TCC 等模式更具弹性,但在某些要求强一致性且事务持续时间较短的场景下,两阶段提交(Two-Phase Commit, 2PC)因其协议简单、一致性保障强,仍然有一席之地。

问题在于,如何优雅地实现一个 2PC 协调器?它不能与业务逻辑紧密耦合。它必须是一个独立、可重用、并且至关重要的是,安全的基础设施组件。我们的技术栈是基于 Rust 和 Axum 构建的,追求高性能和高可靠性。因此,将 2PC 协调器实现为一个通用的 Axum 服务,并通过 Tower 中间件生态进行扩展,似乎是一条可行的路径。

此外,任何对资源进行修改的操作都必须经过严格的认证和授权。服务间的调用也不能例外。因此,这个协调器必须能与我们的 OAuth 2.0 身份认证体系集成,确保只有携带了特定权限(Scope)的请求才能发起一个分布式事务。

最终目标是:构建一个独立的、可通过 HTTP 调用的 2PC 协调器。它接收一个事务参与者列表和一个初始请求,负责协调所有参与者的 preparecommit/rollback 阶段。整个过程由一个具备特定 scope 的 OAuth 2.0 Bearer Token 保护。

技术选型与协议设计

  1. 框架: Axum。它基于 hypertokio,性能卓越,并且其 Tower 中间件架构提供了强大的可组合性。
  2. 安全性: OAuth 2.0。使用 Bearer Token 进行服务间认证。协调器需要扮演资源服务器的角色,对接收到的 Token 进行自省(Introspection)或验签,以验证其有效性和权限范围。
  3. 2PC 协议: 经典的请求-准备-提交模型。
    • 协调器 (Coordinator): 接收外部请求,向所有参与者发送 prepare 请求。如果所有参与者都返回成功,则向所有参与者发送 commit 请求;否则,发送 rollback 请求。
    • 参与者 (Participant): 暴露 /prepare, /commit, /rollback 三个接口。在 prepare 阶段锁定资源,在 commit 阶段应用更改,在 rollback 阶段释放资源。

为了让这个协调器通用化,我们需要定义清晰的 API 契约。

协调器 API (POST /transaction)

请求体:

{
  "participants": [
    {
      "id": "order_service",
      "endpoints": {
        "prepare": "http://order-service/prepare",
        "commit": "http://order-service/commit",
        "rollback": "http://order-service/rollback"
      }
    },
    {
      "id": "wallet_service",
      "endpoints": {
        "prepare": "http://wallet-service/prepare",
        "commit": "http://wallet-service/commit",
        "rollback": "http://wallet-service/rollback"
      }
    }
  ],
  "payload": {
    "user_id": "user-123",
    "order_id": "order-abc",
    "amount": 100
  }
}

协调器会把 payload 部分透传给每个参与者的 prepare 接口。

参与者 API

  • POST /prepare: 接收 payload,锁定资源,返回 200 OK 或错误码。
  • POST /commit: 确认事务,应用变更。
  • POST /rollback: 取消事务,释放资源。

核心实现:构建协调器

首先,我们需要管理协调器内部的事务状态。一个简单的状态机就能满足需求。

// src/coordinator/state.rs

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TransactionState {
    Init,
    Preparing,
    Prepared,
    Committing,
    RollingBack,
    Committed,
    Aborted,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Participant {
    pub id: String,
    pub endpoints: HashMap<String, String>,
}

#[derive(Debug, Clone, Serialize)]
pub struct Transaction {
    pub id: String,
    pub state: TransactionState,
    pub participants: Vec<Participant>,
    #[serde(skip_serializing)]
    pub payload: serde_json::Value,
}

// 在真实项目中,这里应该使用持久化存储,如数据库或日志文件
// 为了简化示例,我们使用内存中的 HashMap
pub type TransactionStore = Arc<Mutex<HashMap<String, Transaction>>>;

接着,是协调器的核心逻辑。它接收请求,启动状态机,并与参与者进行通信。

// src/coordinator/handler.rs

use axum::{extract::State, http::StatusCode, Json};
use serde::Deserialize;
use serde_json::Value;
use tracing::{error, info, instrument};
use uuid::Uuid;
use std::time::Duration;

use super::state::{Participant, Transaction, TransactionState, TransactionStore};

#[derive(Deserialize)]
pub struct TransactionRequest {
    participants: Vec<Participant>,
    payload: Value,
}

#[instrument(skip(store))]
pub async fn handle_transaction(
    State(store): State<TransactionStore>,
    Json(request): Json<TransactionRequest>,
) -> (StatusCode, Json<Value>) {
    let tx_id = Uuid::new_v4().to_string();
    let mut transaction = Transaction {
        id: tx_id.clone(),
        state: TransactionState::Init,
        participants: request.participants,
        payload: request.payload,
    };

    info!(tx_id = %tx_id, "Transaction initiated");
    
    // 1. 存入状态,为后续恢复做准备(虽然本例是内存)
    {
        let mut store_lock = store.lock().await;
        transaction.state = TransactionState::Preparing;
        store_lock.insert(tx_id.clone(), transaction.clone());
    }

    // 2. Prepare 阶段
    let prepare_futs = transaction
        .participants
        .iter()
        .map(|p| call_participant(p, "prepare", &transaction.payload));
    
    let prepare_results = futures::future::join_all(prepare_futs).await;

    let all_prepared = prepare_results.iter().all(|res| res.is_ok());

    if all_prepared {
        info!(tx_id = %tx_id, "All participants prepared successfully. Committing...");
        update_tx_state(&store, &tx_id, TransactionState::Committing).await;

        // 3a. Commit 阶段
        let commit_futs = transaction
            .participants
            .iter()
            .map(|p| call_participant(p, "commit", &Value::Null));
        
        // 在生产环境中,commit阶段的失败需要更复杂的重试或手动干预机制
        // 因为prepare成功后,资源已锁定,必须有一个最终结果
        futures::future::join_all(commit_futs).await;
        
        update_tx_state(&store, &tx_id, TransactionState::Committed).await;
        info!(tx_id = %tx_id, "Transaction committed.");
        
        let response = serde_json::json!({ "status": "committed", "transaction_id": tx_id });
        (StatusCode::OK, Json(response))

    } else {
        error!(tx_id = %tx_id, "One or more participants failed to prepare. Rolling back...");
        update_tx_state(&store, &tx_id, TransactionState::RollingBack).await;

        // 3b. Rollback 阶段
        let rollback_futs = transaction
            .participants
            .iter()
            .map(|p| call_participant(p, "rollback", &Value::Null));

        futures::future::join_all(rollback_futs).await;
        
        update_tx_state(&store, &tx_id, TransactionState::Aborted).await;
        error!(tx_id = %tx_id, "Transaction aborted.");

        let response = serde_json::json!({ "status": "aborted", "transaction_id": tx_id });
        (StatusCode::INTERNAL_SERVER_ERROR, Json(response))
    }
}


async fn call_participant(participant: &Participant, phase: &str, payload: &Value) -> Result<(), String> {
    let endpoint = participant.endpoints.get(phase)
        .ok_or_else(|| format!("Participant {} has no endpoint for phase {}", participant.id, phase))?;

    let client = reqwest::Client::new();
    
    info!("Calling {} on participant {}", phase, participant.id);

    // 这里的超时设置至关重要,防止某个参与者无响应导致整个事务阻塞
    let res = client.post(endpoint)
        .json(payload)
        .timeout(Duration::from_secs(5))
        .send()
        .await;

    match res {
        Ok(response) if response.status().is_success() => {
             info!("Participant {} responded OK for phase {}", participant.id, phase);
             Ok(())
        },
        Ok(response) => {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            error!(
                "Participant {} failed for phase {}. Status: {}, Body: {}",
                participant.id, phase, status, body
            );
            Err(format!("Participant {} failed phase {}", participant.id, phase))
        },
        Err(e) => {
            error!(
                "Failed to call participant {} for phase {}: {}",
                participant.id, phase, e
            );
            Err(format!("Network error for participant {}", participant.id))
        }
    }
}


async fn update_tx_state(store: &TransactionStore, tx_id: &str, new_state: TransactionState) {
    let mut store_lock = store.lock().await;
    if let Some(tx) = store_lock.get_mut(tx_id) {
        tx.state = new_state;
    }
}

集成 OAuth 2.0 鉴权

这是将一个简单原型提升为生产级组件的关键。我们将实现一个 Axum 中间件,它会在 handle_transaction 之前运行。

这个中间件的职责是:

  1. Authorization 头中提取 Bearer Token。
  2. 调用 OAuth 2.0 认证服务器的 Introspection 端点来验证 Token。
  3. 检查返回的 Token 信息中是否包含我们要求的 scope,例如 transaction:execute
  4. 如果验证通过,则将请求传递给下一个处理器;否则,直接返回 401 Unauthorized403 Forbidden
// src/auth.rs

use axum::{
    async_trait,
    extract::{FromRequestParts, Request, State},
    http::{header, request::Parts, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
    Json,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::warn;

// 模拟的OAuth 2.0认证服务器配置
#[derive(Clone)]
pub struct AuthConfig {
    pub introspection_endpoint: String,
    pub required_scope: String,
}

#[derive(Deserialize, Debug)]
struct IntrospectionResponse {
    active: bool,
    scope: Option<String>,
}

pub async fn auth_middleware(
    State(auth_config): State<Arc<AuthConfig>>,
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let token = request.headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.strip_prefix("Bearer "));

    let token = match token {
        Some(t) => t,
        None => {
            warn!("Missing authorization token");
            return Err(StatusCode::UNAUTHORIZED);
        }
    };

    let client = reqwest::Client::new();
    let params = [("token", token)];
    
    // 在真实世界中,你可能还需要向认证服务器提供客户端凭证
    let res = client
        .post(&auth_config.introspection_endpoint)
        .form(&params)
        .send()
        .await
        .map_err(|e| {
            error!("Failed to contact introspection endpoint: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;

    if !res.status().is_success() {
        warn!("Introspection endpoint returned non-success status: {}", res.status());
        return Err(StatusCode::UNAUTHORIZED);
    }

    let intro_res: IntrospectionResponse = res.json().await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

    if !intro_res.active {
        warn!("Token is not active");
        return Err(StatusCode::UNAUTHORIZED);
    }
    
    let scopes = intro_res.scope.unwrap_or_default();
    if !scopes.split_whitespace().any(|s| s == auth_config.required_scope) {
        warn!("Token is missing required scope: {}", auth_config.required_scope);
        return Err(StatusCode::FORBIDDEN);
    }
    
    // 可以将验证后的身份信息放入请求扩展中,供下游服务使用
    // request.extensions_mut().insert(...);

    Ok(next.run(request).await)
}

组装应用与模拟参与者

现在,我们将所有部分组装起来。

// src/main.rs

use axum::{
    middleware,
    routing::{post, get},
    Router,
};
use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::sync::Mutex;
use tracing::{info, Level};
use tracing_subscriber::FmtSubscriber;

mod auth;
mod coordinator;

use auth::{auth_middleware, AuthConfig};
use coordinator::{handler::handle_transaction, state::{TransactionStore, Transaction}};

#[tokio::main]
async fn main() {
    let subscriber = FmtSubscriber::builder()
        .with_max_level(Level::INFO)
        .finish();
    tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");

    // 启动一个模拟的认证服务器
    tokio::spawn(run_mock_auth_server());
    // 启动一个模拟的参与者服务
    tokio::spawn(run_mock_participant_server());

    let store: TransactionStore = Arc::new(Mutex::new(HashMap::new()));
    let auth_config = Arc::new(AuthConfig {
        introspection_endpoint: "http://127.0.0.1:9001/introspect".to_string(),
        required_scope: "transaction:execute".to_string(),
    });

    let app = Router::new()
        .route("/transaction", post(handle_transaction))
        .route_layer(middleware::from_fn_with_state(
            auth_config.clone(),
            auth_middleware,
        ))
        // 添加一个状态查询接口,方便调试
        .route("/transactions/:id", get(get_transaction_status))
        .with_state(store);

    let addr = SocketAddr::from(([127, 0, 0, 1], 9000));
    info!("Coordinator listening on {}", addr);
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}


async fn get_transaction_status(
    axum::extract::Path(tx_id): axum::extract::Path<String>,
    State(store): State<TransactionStore>
) -> impl axum::response::IntoResponse {
    let store_lock = store.lock().await;
    if let Some(tx) = store_lock.get(&tx_id) {
        (StatusCode::OK, Json(tx.clone())).into_response()
    } else {
        StatusCode::NOT_FOUND.into_response()
    }
}


// --- 模拟服务 ---

// 模拟一个认证服务器,用于Token自省
async fn run_mock_auth_server() {
    let app = Router::new().route("/introspect", post(|body: String| async {
        // 简化逻辑:如果token是"valid_token_with_scope", 则认为是有效的
        if body.contains("valid_token_with_scope") {
            Json(serde_json::json!({
                "active": true,
                "scope": "read write transaction:execute"
            }))
        } else {
            Json(serde_json::json!({ "active": false }))
        }
    }));

    let addr = SocketAddr::from(([127, 0, 0, 1], 9001));
    info!("Mock Auth Server listening on {}", addr);
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}


// 模拟一个事务参与者
async fn run_mock_participant_server() {
    let app = Router::new()
        .route("/prepare", post(|| async { 
            info!("[Participant] Received prepare request. Locking resources.");
            StatusCode::OK 
        }))
        .route("/commit", post(|| async { 
            info!("[Participant] Received commit request. Applying changes.");
            StatusCode::OK
        }))
        .route("/rollback", post(|| async {
            info!("[Participant] Received rollback request. Releasing resources.");
            StatusCode::OK
        }));

    let addr = SocketAddr::from(([127, 0, 0, 1], 9002));
    info!("Mock Participant Server listening on {}", addr);
    let listener = tokio::net::TcpListener::bind(addr).await.unwrap();
    axum::serve(listener, app).await.unwrap();
}

测试与验证

  1. 启动服务: cargo run

  2. 使用有效 Token 发起事务(成功场景):

curl -X POST http://127.0.0.1:9000/transaction \
-H "Content-Type: application/json" \
-H "Authorization: Bearer valid_token_with_scope" \
-d '{
  "participants": [
    {
      "id": "mock_service_1",
      "endpoints": {
        "prepare": "http://127.0.0.1:9002/prepare",
        "commit": "http://127.0.0.1:9002/commit",
        "rollback": "http://127.0.0.1:9002/rollback"
      }
    }
  ],
  "payload": {
    "some_data": "value"
  }
}'

预期输出: {"status":"committed","transaction_id":"..."}。同时,控制台日志会显示完整的 prepare -> commit 流程。

  1. 使用无效 Token(鉴权失败):
curl -v -X POST http://127.0.0.1:9000/transaction \
-H "Content-Type: application/json" \
-H "Authorization: Bearer invalid_token" \
-d '{}'

预期输出: HTTP/1.1 401 Unauthorized

  1. 模拟 Prepare 失败(回滚场景):

为了模拟这个,我们可以临时修改模拟参与者的 /prepare 接口,让它返回一个非 200 的状态码。例如,修改为 StatusCode::INTERNAL_SERVER_ERROR。然后重新运行并发送与第一步相同的请求。

预期输出: {"status":"aborted","transaction_id":"..."}。控制台日志会显示 prepare 失败,然后触发 rollback

sequenceDiagram
    participant Client
    participant Coordinator
    participant AuthServer
    participant ParticipantA
    participant ParticipantB

    Client->>+Coordinator: POST /transaction (Payload, Token)
    Coordinator->>+AuthServer: /introspect (Token)
    AuthServer-->>-Coordinator: {active: true, scope: "..."}
    Note right of Coordinator: Token and scope are valid.
    Coordinator->>+ParticipantA: POST /prepare (Payload)
    ParticipantA-->>-Coordinator: 200 OK
    Coordinator->>+ParticipantB: POST /prepare (Payload)
    ParticipantB-->>-Coordinator: 200 OK
    Note right of Coordinator: All participants prepared.
    Coordinator->>+ParticipantA: POST /commit
    ParticipantA-->>-Coordinator: 200 OK
    Coordinator->>+ParticipantB: POST /commit
    ParticipantB-->>-Coordinator: 200 OK
    Coordinator-->>-Client: 200 OK {status: "committed"}

局限性与未来展望

这个实现虽然验证了核心思路,但在生产环境中部署前,还有几个关键问题需要解决。

首先,协调器的单点故障和状态持久化。当前实现将所有事务状态保存在内存中,一旦协调器重启,所有进行中的事务信息都会丢失。这是一个致命缺陷。在真实项目中,必须将事务日志持久化到高可用的存储中,比如 etcd、ZooKeeper 或者一个支持事务的数据库(如 PostgreSQL)。协调器在重启后,需要读取日志来恢复中断的事务,决定是继续提交还是回滚。

其次,2PC 的固有缺陷。它是一个阻塞协议。在 prepare 成功后,commit 消息到达前,所有参与者必须锁定资源。如果协调器宕机,这些资源将被无限期锁定,直到协调器恢复。这会严重影响系统的可用性。对于需要长时间运行或者跨网络边界的事务,Saga 模式可能是更好的选择。

最后,参与者的幂等性。网络是不可靠的,协调器可能会重发 commitrollback 请求。参与者的接口必须设计成幂等的,确保重复执行不会产生副作用。

未来的优化路径可以包括:

  1. 实现协调器的高可用:通过 Raft 或 Paxos 协议将协调器做成集群,避免单点故障。
  2. 集成持久化事务日志:使用 sledRocksDB 等嵌入式数据库或外部数据库来存储事务状态。
  3. 提供更丰富的查询 API:允许查询所有进行中或失败的事务,方便运维人员手动介入。
  4. 探索 3PC:三阶段提交通过引入一个 canCommit 阶段来缓解 2PC 的阻塞问题,但协议也更复杂。

  目录