Netty自定义协议

2019-07-30  本文已影响0人  横渡

将消息定义结构为消息头和消息体两部分,消息头中存储消息的长度。netty读取消息头后,就能知道消息体的长度了。

自定义协议

协议开始标志 长度 数据
  1. 协议开始标志head_data,为int类型的数据,16进制表示为0X76;
  2. 传输数据的长度contentLength,int类型;
  3. 要传输的数据。

自定义协议类的封装:

package learn.netty.protocal;

/**
 * @author stone
 * @date 2019/7/29 17:38
 */
public class ConstantValue {
    /**
     * 自定义协议,报文开始标志
     */
    public static final int HEAD_DATA = 0X76;
}

package learn.netty.protocal;

import java.io.UnsupportedEncodingException;

/**
 * @author stone
 * @date 2019/7/29 17:39
 */
public class MyLsProtocol {
    /**
     * 消息头标志
     */
    private int header = ConstantValue.HEAD_DATA;

    /**
     * 消息长度
     */
    private int contentLength;

    /**
     * 消息内容
     */
    private byte[] content;

    public MyLsProtocol(int contentLength, byte[] content) {
        this.contentLength = contentLength;
        this.content = content;
    }

    public int getContentLength() {
        return contentLength;
    }

    public void setContentLength(int contentLength) {
        this.contentLength = contentLength;
    }

    public byte[] getContent() {
        return content;
    }

    public void setContent(byte[] content) {
        this.content = content;
    }

    public int getHeader() {
        return header;
    }

    @Override
    public String toString() {
        try {
            return "MyLsProtocol{" +
                    "header=" + header +
                    ", contentLength=" + contentLength +
                    ", content=" + new String(content, "utf-8") +
                    '}';
        } catch (UnsupportedEncodingException e) {
            e.printStackTrace();
            return "";
        }
    }
}

自定义协议的编码器:

package learn.netty.protocal;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;

/**
 * 自定义协议编码器(对象转为字节)
 *
 * @author stone
 * @date 2019/7/29 17:41
 */
public class LsEncoder extends MessageToByteEncoder<MyLsProtocol> {
    @Override
    public void encode(ChannelHandlerContext ctx, MyLsProtocol msg, ByteBuf out) throws Exception {
        out.writeInt(msg.getHeader());
        out.writeInt(msg.getContentLength());
        out.writeBytes(msg.getContent());
    }
}

自定义协议的解码器:

package learn.netty.protocal;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;

import java.util.List;

/**
 * 自定义协议解码器 (字节转对象)
 *
 * @author stone
 * @date 2019/7/29 17:43
 */
public class LsDecoder extends ByteToMessageDecoder {
    /**
     * 报文开始的标志 header是int类型,占4个字节
     * 表示报文长度的contentLength是int类型,占4个字节
     */
    public final int BASE_LENGTH = 8;

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        // 可读长度必须大于基本长度
        if (in.readableBytes() >= BASE_LENGTH) {
            // 防止socket字节流攻击
            // 防止客户端传来的数据过大(太大的数据是不合理的)
            if (in.readableBytes() > 2048) {
                in.skipBytes(in.readableBytes());
            }

            // 记录包头开始的index
            int beginReader;
            while (true) {
                // 获取包头开始的index
                beginReader = in.readerIndex();
                // 标记包头开始的index
                in.markReaderIndex();
                // 读到协议的开始标志,结束while循环
                if (in.readInt() == ConstantValue.HEAD_DATA) {
                    break;
                }

                // 未读到包头,跳过一个字节
                // 每次跳过一个字节后,再去读取包头信息的开始标记
                in.resetReaderIndex();
                in.readByte();

                // 当跳过一个字节后,数据包的长度又变的不满足,此时应该结束,等待后边数据流的到达
                if (in.readableBytes() < BASE_LENGTH) {
                    return;
                }
            }

            // 代码到这里,说明已经读到了报文标志

            // 消息长度
            int length = in.readInt();
            // 判断请求数据包是否到齐
            if (in.readableBytes() < length) { // 数据不齐,回退读指针
                // 还原读指针
                in.readerIndex(beginReader);
                return;
            }

            // 至此,读到一条完整报文
            byte[] data = new byte[length];
            in.readBytes(data);
            MyLsProtocol protocol = new MyLsProtocol(data.length, data);
            out.add(protocol);
        }
    }
}

服务端实现:

package learn.netty.protocal;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;

/**
 * @author stone
 * @date 2019/7/30 9:30
 */
public class ProtocolServer {
    public ProtocolServer() {
    }

    public void bind(int port) throws Exception {
        // 配置IO线程组
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            // 服务器辅助启动类配置
            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .handler(new LoggingHandler(LogLevel.INFO))
                    .childHandler(new ChildChannelHandler())
                    .option(ChannelOption.SO_BACKLOG, 1024) // 设置tcp缓冲区
                    .option(ChannelOption.SO_KEEPALIVE, true);
            // 绑定端口,同步等待绑定成功
            ChannelFuture f = b.bind(port).sync();
            // 等待服务端监听端口关闭
            f.channel().closeFuture().sync();
        } finally {
            bossGroup.shutdownGracefully();
            workerGroup.shutdownGracefully();
        }
    }

// 服务端加入的协议编码/解码器
    public class ChildChannelHandler extends ChannelInitializer<SocketChannel> {
        @Override
        protected void initChannel(SocketChannel ch) throws Exception {
            // 添加自定义协议编码工具
            ch.pipeline().addLast(new LsEncoder());
            ch.pipeline().addLast(new LsDecoder());
            // 处理网络IO
            ch.pipeline().addLast(new ProtocolServerHandler());
        }
    }

    public static void main(String[] args) throws Exception {
        new ProtocolServer().bind(9999);
    }
}

服务端handler实现:

package learn.netty.protocal;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;

/**
 * @author stone
 * @date 2019/7/30 9:33
 */
public class ProtocolServerHandler extends ChannelInboundHandlerAdapter {

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        // 用于获取客户端发来的数据信息
        MyLsProtocol body = (MyLsProtocol) msg;
        System.out.println("Server接收到的客户端信息:" + body.toString());

        // 写数据给客户端
        String str = "Hi I am Server ...";
        MyLsProtocol response = new MyLsProtocol(str.getBytes().length, str.getBytes());
        // 当服务端完成写操作后,关闭与客户端的连接
        ctx.writeAndFlush(response);

        // 有写操作时,不需要手动释放msg的引用; 当只有读操作时,才需要手动释放msg的引用
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        System.out.println("Server exceptionCaught");
        cause.printStackTrace();
//        if (ctx.channel().isActive()) {
//            ctx.channel().writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
//        }
        ctx.close();
    }
}

客户端实现:

package learn.netty.protocal;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;

/**
 * @author stone
 * @date 2019/7/30 9:45
 */
public class ProtocolClient {

    public void connect(int port, String host) throws Exception {
        // 配置客户端NIO线程组
        EventLoopGroup group = new NioEventLoopGroup();
        try {
            // 配置启动辅助类
            Bootstrap b = new Bootstrap();
            b.group(group)
                    .channel(NioSocketChannel.class)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(new MyChannelHandler());
            // 异步连接服务器,同步等待连接成功
            ChannelFuture f = b.connect(host, port).sync();
            // 等待连接关闭
            f.channel().closeFuture().sync();
        } finally {
            group.shutdownGracefully();
        }
    }

    public static void main(String[] args) throws Exception {
        new ProtocolClient().connect(9999, "127.0.0.1");
    }

// 客户端加入的协议编码/解码器
    public class MyChannelHandler extends ChannelInitializer<SocketChannel> {

        @Override
        protected void initChannel(SocketChannel ch) throws Exception {
            ch.pipeline().addLast(new LsEncoder());
            ch.pipeline().addLast(new LsDecoder());
            // 处理网络IO
            ch.pipeline().addLast(new ProtocolClientHandler());
        }
    }
}

客户端handler实现:

package learn.netty.protocal;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;

/**
 * 客户端Handler
 * @author
 */
public class ProtocolClientHandler extends ChannelInboundHandlerAdapter {

    /**
     * 客户端一旦与服务端建立好连接,就会触发该方法
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        // 发送消息
        String data = "I am client ...";
        // 获取要发送信息的字节数组
        byte[] content = data.getBytes();
        // 要发送信息的长度
        int contentLength = content.length;

        MyLsProtocol protocol = new MyLsProtocol(contentLength, content);
        for (int i = 0; i < 100; i++) {
            ctx.writeAndFlush(protocol);
        }
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        try {
            // 用于获取客户端发来的数据信息
            MyLsProtocol body = (MyLsProtocol) msg;
            System.out.println("Client接收的客户端的信息:" + body.toString());
        } finally {
            ReferenceCountUtil.release(msg);
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        ctx.close();
    }
}

参考博客:
http://www.cnblogs.com/whthomas/p/netty-custom-protocol.html
http://www.cnblogs.com/fanguangdexiaoyuer/p/6131042.html

上一篇下一篇

猜你喜欢

热点阅读