添加依赖（Cargo.toml）
[dependencies]
# macros: tokio::select!; rt: tokio::spawn; net: tokio::net::TcpListener;
# sync: tokio::sync::broadcast (used for AppState::tx).
tokio = { version = "1.38.0", features = ["macros", "rt", "net", "sync"] }
futures = "0.3.30"
tower-http = { version = "0.5.2", features = ["cors"] }
axum = { version = "0.7.5", features = ["ws"] }
# Provides dashmap::DashMap, the concurrent map stored in AppState::mapping.
dashmap = "5.5.3"
定义消息类型 MsgData
/// Payload published on the shared broadcast channel and received by every
/// subscribed WebSocket connection.
///
/// `Clone` is needed because a `tokio::sync::broadcast` receiver hands each
/// subscriber its own copy of every value.
#[derive(Clone, Debug)]
pub enum MsgData {
    /// String payload tagged `Add` (the only variant the handlers in this file send).
    Add(String),
    /// String payload tagged `Remove`.
    Remove(String),
    /// String payload tagged `Update`.
    Update(String),
    /// String payload tagged `Refresh`.
    Refresh(String),
}
定义AppState
#[derive(Debug)]
pub struct AppState {
mapping: DashMap<String, String>,
tx: Sender<MsgData>,
}
impl Default for AppState {
fn default() -> Self {
let (tx, _) = broadcast::channel::<MsgData>(64);
Self {
mapping: Default::default(),
tx
}
}
}
Axum Server
use std::{net::SocketAddr, sync::Arc};
use axum::{
http::Method,
routing::get,
Router
};
use dashmap::DashMap;
use tokio::{net::TcpListener, sync::broadcast::{self, Sender}};
use tower_http::cors::{Any, CorsLayer};
/// Builds the router and serves it on `127.0.0.1:<port>` until the server stops.
///
/// Routes:
/// - `GET /ws`     – upgrades to a WebSocket (handled by `ws_handler`).
/// - `GET /client` – publishes a message to WebSocket subscribers (`client_handler`).
///
/// A permissive CORS layer (any origin; GET/POST/PUT/PATCH/DELETE) wraps every route.
///
/// # Panics
///
/// Panics if the address cannot be bound or if `axum::serve` returns an I/O
/// error; the panic message names the address and the underlying error.
pub async fn app(port: u16) {
    let app_state = Arc::new(AppState::default());

    // An array converts directly into `AllowMethods`; no Vec allocation needed.
    let cors = CorsLayer::new().allow_origin(Any).allow_methods([
        Method::GET,
        Method::POST,
        Method::PUT,
        Method::PATCH,
        Method::DELETE,
    ]);

    let app = Router::new()
        .route("/ws", get(ws_handler))
        .route("/client", get(client_handler))
        .with_state(app_state)
        .layer(cors);

    let addr = SocketAddr::from(([127, 0, 0, 1], port));
    let listener = TcpListener::bind(addr)
        .await
        .unwrap_or_else(|err| panic!("failed to bind {addr}: {err}"));
    println!("Backend is listening on http://{}", addr);

    axum::serve(listener, app)
        .await
        .unwrap_or_else(|err| panic!("server on {addr} stopped with an error: {err}"));
}
HTTP Client Handler (`GET /client`: broadcasts a message to all WebSocket subscribers)
use axum::{extract::State, response::IntoResponse};
/// `GET /client`: publishes `MsgData::Add("abc")` on the shared broadcast
/// channel, logs whether the send succeeded, and replies with `"Hello!"`.
///
/// The send fails when no WebSocket connection is currently subscribed; that
/// case is only logged, and the HTTP response is the same either way.
async fn client_handler(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    let outcome = state.tx.send(MsgData::Add(String::from("abc")));
    if let Err(err) = outcome {
        println!("Send Message Error: {:?}", err);
    } else {
        println!("Send Message Success!");
    }
    "Hello!"
}
WebSocket Handler (`GET /ws`: upgrade and per-connection tasks)
use axum::extract::ws::{Message, WebSocket};
use axum::response::IntoResponse;
use axum::extract::{State, WebSocketUpgrade};
use futures::{SinkExt, StreamExt};
async fn ws_handler(ws: WebSocketUpgrade, State(state): State<Arc<AppState>>) -> impl IntoResponse {
ws.on_upgrade(|socket| handle_socket(socket, state))
}
async fn handle_socket(socket: WebSocket, state: Arc<AppState>) {
let (mut sender, mut receiver) = socket.split();
let mut recv_task = tokio::spawn(async move {
while let Some(Ok(Message::Text(data))) = receiver.next().await {
println!("this example does not read any messages, but got: {data}");
}
});
let mut rx = state.tx.subscribe();
let mut send_task = tokio::spawn(async move {
while let Ok(data) = rx.recv().await {
println!("Receive Data: {:?}", data);
if sender.send(Message::Text("Hello World!".into())).await.is_err() {
println!("send warnning!");
}
}
});
tokio::select! {
_v1 = &mut recv_task => send_task.abort(),
_v2 = &mut send_task => recv_task.abort(),
}
println!("handle websocket over");
}