use std::{ collections::HashSet, sync::{Arc, Mutex}, }; use axum::{ Router, extract::{ State, WebSocketUpgrade, ws::{Message, Utf8Bytes, WebSocket}, }, response::{Html, IntoResponse}, routing::get, }; use futures_util::{sink::SinkExt, stream::StreamExt}; use tokio::sync::broadcast; struct AppState { user_set: Mutex>, tx: broadcast::Sender, } #[tokio::main] async fn main() { let user_set = Mutex::new(HashSet::new()); let (tx, _rx) = broadcast::channel(100); let app_state = Arc::new(AppState { user_set, tx }); let app = Router::new() .route("/", get(index)) .route("/websocket", get(websocket_handler)) .with_state(app_state); let listener = tokio::net::TcpListener::bind("localhost:4000") .await .unwrap(); axum::serve(listener, app).await.unwrap(); } async fn index() -> Html<&'static str> { Html(include_str!("../chat.html")) } async fn websocket_handler( ws: WebSocketUpgrade, State(state): State>, ) -> impl IntoResponse { ws.on_upgrade(|socket| websocket(socket, state)) } async fn websocket(socket: WebSocket, state: Arc) { let (mut sender, mut receiver) = socket.split(); let mut username = String::new(); // Loop over received messages. Ignore anything that isn't a `Message::Text()`. // The first text message should be the client's username. If it's available, // set the name and continue to main operation. If it's not, return an error // message and exit the function. while let Some(Ok(message)) = receiver.next().await { if let Message::Text(name) = message { check_username(&state, &mut username, name.as_str()); if !username.is_empty() { break; // exit if we have a name } else { let _ = sender .send(Message::Text(Utf8Bytes::from_static( "Username is already taken!", ))) .await; return; } } } let mut rx = state.tx.subscribe(); let _ = state.tx.send(format!("{username} joined the lobby!")); // Read messages broadcast through the server, write them to this socket. // If any error is returned, break the loop to terminate the task. We're // not dealing with them right now. let mut send_task = tokio::spawn(async move { while let Ok(msg) = rx.recv().await { if sender.send(Message::text(msg)).await.is_err() { break; } } }); let tx = state.tx.clone(); let name = username.clone(); let mut receive_task = tokio::spawn(async move { while let Some(Ok(Message::Text(msg))) = receiver.next().await { let _ = tx.send(format!("{name} -- {msg}")); } }); // If either task completes, abort the other. They loop infinitely as long // as the user is connected, so when one stops it means they should both // stop. tokio::select! { _ = &mut send_task => receive_task.abort(), _ = &mut receive_task => send_task.abort(), } // Remove the username from the set when the socket is closing down. state.user_set.lock().unwrap().remove(&username); } /// Sets the requested username into buffer `name_out` if it is currently unused in the lobby. /// /// Check for presence of `name` in `state.user_set`. If taken, the `name_out` out-parameter /// is left unchanged (which should be empty, signaling to the caller that the name is /// unavailable). /// /// If the name is available, it is added to `state.user_set` (thus making it unavailable /// going forward) and written into the `name_out` buffer for use by the caller (non-empty /// values signal that the name has been accepted for use). fn check_username(state: &AppState, name_out: &mut String, name: &str) { // TODO: Return a Result instead of using out-parameters. This isn't C, // we can do better. let mut user_set = state.user_set.lock().unwrap(); if !user_set.contains(name) { user_set.insert(name.to_owned()); name_out.push_str(name); } }