# websocket--3.5.netty方式03

# core

package com.zs.websocket.netty.core;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Configuration;

/**
 * Spring-managed lifecycle wrapper around the Netty WebSocket server.
 *
 * <p>The server is bound in {@link #afterPropertiesSet()} and torn down in {@link #destroy()}.
 * Port, path and event-loop thread counts are read from {@link WebSocketProperties}.
 */
@Slf4j
@Configuration
public class NettyWebSocketServer implements InitializingBean, DisposableBean {
    @Autowired
    private WebSocketProperties webSocketProperties;

    @Autowired
    private NioWebSocketChannelInitializer webSocketChannelInitializer;

    private EventLoopGroup bossGroup;
    private EventLoopGroup workGroup;
    private ChannelFuture channelFuture;

    /**
     * Creates the event loop groups, binds the server socket and blocks until the bind completes.
     *
     * <p>If anything fails, both event loop groups are shut down before the original exception
     * propagates, so no event-loop threads are left running.
     *
     * @throws Exception if the event loop groups cannot be created or the port cannot be bound
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        boolean started = false;
        try {
            bossGroup = new NioEventLoopGroup(webSocketProperties.getBoss());
            workGroup = new NioEventLoopGroup(webSocketProperties.getWork());

            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap
                    // boss group accepts connections on the server channel; work group does the
                    // reads/writes on the accepted channels
                    .group(bossGroup, workGroup)
                    // NIO based server socket implementation
                    .channel(NioServerSocketChannel.class)
                    // accept queue length requested for the listening socket
                    .option(ChannelOption.SO_BACKLOG, 1024)
                    // keep-alive is a per-connection option, so it has to be a child option;
                    // set as a plain option it would only be offered to the listening socket
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    // listening port
                    .localAddress(webSocketProperties.getPort())
                    // handlers installed on every accepted channel (not on the server channel)
                    .childHandler(webSocketChannelInitializer);
            // bind and block until the bind has completed; sync() rethrows a bind failure
            channelFuture = serverBootstrap.bind().sync();
            started = true;
            log.info("netty服务启动端口: {} (websocket) 连接地址 '{}'", webSocketProperties.getPort(), webSocketProperties.getPath());
        } finally {
            if (!started) {
                log.error("netty服务启动失败");
                // must not throw here, otherwise the real startup failure would be masked
                shutdownEventLoopGroups();
            }
        }
    }

    /**
     * Stops accepting connections, then shuts down both event loop groups and waits for them to
     * terminate.
     *
     * @throws Exception declared by {@link DisposableBean#destroy()}
     */
    @Override
    public void destroy() throws Exception {
        log.info("Shutting down Netty server...");
        if (channelFuture != null) {
            // close the listening channel first so no new connections arrive during shutdown
            channelFuture.channel().close().awaitUninterruptibly();
        }
        shutdownEventLoopGroups();
        log.info("Netty server shutdown.");
    }

    /**
     * Gracefully shuts down whichever event loop groups were created and waits for termination.
     * Both shutdowns are requested before waiting so that they proceed concurrently.
     */
    private void shutdownEventLoopGroups() {
        if (bossGroup != null) {
            bossGroup.shutdownGracefully();
        }
        if (workGroup != null) {
            workGroup.shutdownGracefully();
        }
        if (bossGroup != null) {
            bossGroup.terminationFuture().awaitUninterruptibly();
        }
        if (workGroup != null) {
            workGroup.terminationFuture().awaitUninterruptibly();
        }
    }

}
package com.zs.websocket.netty.core;

import com.zs.websocket.netty.core.WebSocketProperties;
import com.zs.websocket.netty.handler.NioWebSocketHandler;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.stream.ChunkedWriteHandler;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * Builds the handler pipeline for every accepted connection: HTTP codec, chunked writer,
 * HTTP message aggregation, the application handler and finally the WebSocket protocol handler.
 */
@Component
public class NioWebSocketChannelInitializer extends ChannelInitializer<SocketChannel> {

    /** Maximum aggregated HTTP content length, in bytes. */
    private static final int MAX_HTTP_CONTENT_LENGTH = 8192;
    /** Maximum WebSocket frame payload length, in bytes. */
    private static final int MAX_FRAME_PAYLOAD_LENGTH = 65536;

    @Autowired
    private WebSocketProperties webSocketProperties;
    @Autowired
    private NioWebSocketHandler nioWebSocketHandler;

    /**
     * Installs the handlers on a newly accepted channel, in inbound order.
     *
     * @param socketChannel the accepted channel
     */
    @Override
    protected void initChannel(SocketChannel socketChannel) {
        // The handshake is plain HTTP, so an HTTP encoder/decoder is required.
        HttpServerCodec httpCodec = new HttpServerCodec();
        // Supports writing data in chunks.
        ChunkedWriteHandler chunkedWriter = new ChunkedWriteHandler();
        // HTTP messages arrive in several parts; the aggregator merges them into one full message.
        HttpObjectAggregator aggregator = new HttpObjectAggregator(MAX_HTTP_CONTENT_LENGTH);
        // Upgrades HTTP to the WebSocket protocol on the configured path; WebSocket data then
        // travels as frames.
        WebSocketServerProtocolHandler protocolHandler = new WebSocketServerProtocolHandler(
                webSocketProperties.getPath(), null, true, MAX_FRAME_PAYLOAD_LENGTH);

        // The application handler is registered ahead of the protocol handler, so it sees the
        // handshake request before the upgrade is attempted.
        socketChannel.pipeline()
                .addLast(httpCodec, chunkedWriter, aggregator, nioWebSocketHandler, protocolHandler);
    }
}
package com.zs.websocket.netty.core;

import java.util.HashMap;
import java.util.Map;

/**
 * Helpers for splitting a request URI into its path and its query parameters.
 */
public class RequestUriUtils {

  /**
   * Converts the query part of a URI into a map. When a parameter name occurs more than once,
   * the last value wins. Names and values are returned as-is (no URL decoding is applied).
   *
   * <p>A parameter without {@code =} (for example {@code ?flag}) is stored with an empty value,
   * and empty segments (for example a trailing {@code ?} or {@code &&}) are skipped.
   *
   * @param uri the URI, possibly carrying a query string; may be {@code null}
   * @return a mutable map of parameter names to values; empty when there is no query string
   */
  public static Map<String, String> getParams(String uri) {
    Map<String, String> params = new HashMap<>(10);
    if (uri == null) {
      return params;
    }

    int queryStart = uri.indexOf('?');
    if (queryStart == -1) {
      return params;
    }

    for (String param : uri.substring(queryStart + 1).split("&")) {
      if (param.isEmpty()) {
        continue;
      }
      int separator = param.indexOf('=');
      if (separator == -1) {
        // name-only parameter: substring(0, -1) would throw, so store an empty value instead
        params.put(param, "");
      } else {
        params.put(param.substring(0, separator), param.substring(separator + 1));
      }
    }

    return params;
  }

  /**
   * Returns the part of the URI that precedes the query string.
   *
   * @param uri the URI, possibly carrying a query string
   * @return the URI without its query string, or {@code null} when {@code uri} is null or empty
   */
  public static String getBasePath(String uri) {
    if (uri == null || uri.isEmpty()) {
      return null;
    }

    int queryStart = uri.indexOf('?');
    return queryStart == -1 ? uri : uri.substring(0, queryStart);
  }

}
package com.zs.websocket.netty.core;

import lombok.Getter;
import lombok.Setter;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
 * Configuration properties for the WebSocket server, bound from the {@code chat.websocket}
 * prefix. Each field's initializer is the value used when the property is not configured.
 */
@Getter
@Setter
@Component
@ConfigurationProperties(prefix = "chat.websocket")
public class WebSocketProperties {

  private Integer port = 8070; // listening port
  private String path = "/ws"; // request path of the WebSocket endpoint
  private Integer boss = 2; // number of threads in the boss event loop group
  private Integer work = 8; // number of threads in the work event loop group
}

# handler

package com.zs.websocket.netty.handler;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.util.AttributeKey;
import org.apache.commons.lang3.StringUtils;

import java.net.InetSocketAddress;

/**
 * Records the client IP of an incoming HTTP request as the {@link NettyUtil#IP} channel
 * attribute, then passes the message on unchanged.
 */
public class HttpHeadersHandler extends ChannelInboundHandlerAdapter {

    /**
     * For a {@link FullHttpRequest}, stores the client IP on the channel. The {@code X-Real-IP}
     * header is preferred; when it is absent the channel's remote address is used. Every message
     * is forwarded to the next handler.
     *
     * <p>NOTE(review): {@code X-Real-IP} is supplied by the peer, so it is only trustworthy when
     * a reverse proxy in front of this server overwrites it — confirm the deployment.
     *
     * @param ctx the handler context
     * @param msg the inbound message
     * @throws Exception if a later handler fails
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (msg instanceof FullHttpRequest) {
            HttpHeaders headers = ((FullHttpRequest) msg).headers();
            String ip = headers.get("X-Real-IP");
            if (StringUtils.isEmpty(ip)) {
                // no proxy header present: fall back to the socket's remote address
                ip = remoteIp(ctx);
            }
            if (ip != null) {
                NettyUtil.setAttr(ctx.channel(), NettyUtil.IP, ip);
            }
        }
        ctx.fireChannelRead(msg);
    }

    /**
     * Resolves the textual address of the channel's remote peer.
     *
     * @param ctx the handler context
     * @return the remote IP (or host string when the address is unresolved), or {@code null} when
     *         the channel has no internet socket address
     */
    private static String remoteIp(ChannelHandlerContext ctx) {
        if (!(ctx.channel().remoteAddress() instanceof InetSocketAddress)) {
            // null (e.g. already disconnected) or a non-IP transport: nothing to record
            return null;
        }
        InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
        // getAddress() is null for an unresolved address
        return address.getAddress() != null
                ? address.getAddress().getHostAddress()
                : address.getHostString();
    }
}
package com.zs.websocket.netty.handler;

import io.netty.channel.Channel;
import io.netty.util.Attribute;
import io.netty.util.AttributeKey;

/**
 * Description: netty工具类
 */

public class NettyUtil {

    /** Channel attribute key under which the client IP is stored. */
    public static final AttributeKey<String> IP = AttributeKey.valueOf("ip");

    /**
     * Stores a value on the channel under the given attribute key, replacing any previous value.
     *
     * @param channel      the channel to attach the value to
     * @param attributeKey the attribute key
     * @param data         the value to store
     * @param <T>          the attribute value type
     */
    public static <T> void setAttr(Channel channel, AttributeKey<T> attributeKey, T data) {
        Attribute<T> attr = channel.attr(attributeKey);
        attr.set(data);
    }

    /**
     * Reads the value stored on the channel under the given attribute key.
     *
     * @param channel      the channel to read from
     * @param attributeKey the attribute key
     * @param <T>          the attribute value type
     * @return the stored value, or {@code null} when nothing has been set
     */
    public static <T> T getAttr(Channel channel, AttributeKey<T> attributeKey) {
        return channel.attr(attributeKey).get();
    }
}
package com.zs.websocket.netty.handler;

import cn.hutool.http.HttpUtil;
import com.zs.websocket.netty.core.RequestUriUtils;
import com.zs.websocket.netty.core.WebSocketProperties;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.websocketx.*;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.concurrent.ConcurrentHashMap;

@Slf4j
@ChannelHandler.Sharable
@Component
public class NioWebSocketHandler extends SimpleChannelInboundHandler<WebSocketFrame> {

    @Autowired
    private WebSocketProperties webSocketProperties;

    /**
     * 所有已连接的websocket连接列表和一些额外参数
     */
    private static final ConcurrentHashMap<Channel, String> ONLINE_WS_MAP = new ConcurrentHashMap<>();



    //======================================================================
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        log.debug("客户端连接:{}", ctx.channel().id());
        ONLINE_WS_MAP.put(ctx.channel(), ctx.channel().id().asLongText());
        super.channelActive(ctx);
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        log.debug("客户端断开连接:{}", ctx.channel().id());
        // 可能出现业务判断离线后再次触发 channelInactive
        log.warn("触发 channelInactive 掉线![{}]", ctx.channel().id());
        userOffLine(ctx);
        super.channelInactive(ctx);
    }
    // 客户端离线
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        userOffLine(ctx);
    }
    private void userOffLine(ChannelHandlerContext ctx) {
        log.warn("websocket 断开连接");
        ONLINE_WS_MAP.remove(ctx.channel());
        ctx.channel().close();
    }

    // 处理异常
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.warn("异常发生,异常消息 ={}", cause.getMessage());
        ctx.channel().close();
    }

    //======================================================================



    /**
     * 心跳检查
     *
     * @param ctx
     * @param evt
     * @throws Exception
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (evt instanceof IdleStateEvent) {
            IdleStateEvent idleStateEvent = (IdleStateEvent) evt;
            // 读空闲
            if (idleStateEvent.state() == IdleState.READER_IDLE) {
                // 关闭用户的连接
                userOffLine(ctx);
            }
        }
        super.userEventTriggered(ctx, evt);
    }


    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) {
        ctx.channel().flush();
    }
    //======================================================================


    @Override
    protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) {
        //System.out.println("服务器收到消息:" + msg.text());
        // 根据请求数据类型进行分发处理
        if (frame instanceof PingWebSocketFrame) {
            pingWebSocketFrameHandler(ctx, (PingWebSocketFrame) frame);
        } else if (frame instanceof TextWebSocketFrame) {
            textWebSocketFrameHandler(ctx, (TextWebSocketFrame) frame);
        } else if (frame instanceof CloseWebSocketFrame) {
            closeWebSocketFrameHandler(ctx, (CloseWebSocketFrame) frame);
        }

    }

    /**
     * 处理客户端心跳包
     *
     * @param ctx
     * @param frame
     */
    private void pingWebSocketFrameHandler(ChannelHandlerContext ctx, PingWebSocketFrame frame) {
        ctx.channel().writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
    }
    /**
     * 创建连接之后,客户端发送的消息都会在这里处理
     *
     * @param ctx
     * @param frame
     */
    private void textWebSocketFrameHandler(ChannelHandlerContext ctx, TextWebSocketFrame frame) {
        log.debug("服务端接收到消息:{}",frame.text());
        // 客户端发送过来的内容不进行业务处理,原样返回
        ctx.channel().writeAndFlush(frame.retain());
    }

    /**
     * 客户端发送断开请求处理
     *
     * @param ctx
     * @param frame
     */
    private void closeWebSocketFrameHandler(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
        ctx.close();
    }

    //======================================================================


    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        log.debug("客户端请求数据类型:{}", msg.getClass());
        if (msg instanceof FullHttpRequest) {
            fullHttpRequestHandler(ctx, (FullHttpRequest) msg);
        }
        super.channelRead(ctx, msg);
    }
    /**
     * 处理连接请求,客户端WebSocket发送握手包时会执行这一次请求
     *
     * @param ctx
     * @param request
     */
    private void fullHttpRequestHandler(ChannelHandlerContext ctx, FullHttpRequest request) {
        String uri = request.uri();
        HashMap<String, String> params = (HashMap<String, String>) HttpUtil.decodeParamMap(uri, StandardCharsets.UTF_8);
        log.debug("客户端请求参数:{}", params);
        // 判断请求路径是否跟配置中的一致
        if (webSocketProperties.getPath().equals(RequestUriUtils.getBasePath(uri))) {
            // 因为有可能携带了参数,导致客户端一直无法返回握手包,因此在校验通过后,重置请求路径
            request.setUri(webSocketProperties.getPath());
        }else {
            log.error("服务地址uri:{}, 客户端连接uri:{}", webSocketProperties.getPath(), RequestUriUtils.getBasePath(uri) + "请检查路径不一致");
            ctx.close();
        }
    }
}
更新时间: 2024年1月30日星期二下午2点36分