自定义实现grpc拦截器

2023-10-15  本文已影响0人  小胖学编程

使用;框架进行通信时,有时候需要对编写拦截器对请求或者响应对象进行拦截。如何实现拦截呢?

服务端

服务端拦截器如下图所示:

serverCall:是响应的回调接口,可以用于直接关闭请求;

一般拦截器返回的是next.startCall(serverCall, headers);但是如果想获取到请求对象或者响应对象,需要通过装饰器模式来进行增强,在增强的时候,可以做一些处理。

public class TestServerInterceptor implements ServerInterceptor {

    public static final Metadata.Key<Long> USER_ID_KEY = Metadata.Key.of("userId", ASCII_LONG_MARSHALLER);
    public static final Metadata.Key<String> OPT_NAME_KEY = Metadata.Key.of("userName", ASCII_STRING_MARSHALLER);
    public static final ThreadLocal<Long> USER_THREAD_LOCAL = new ThreadLocal<>();


    @Override
    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> serverCall, Metadata headers,
                                                                 ServerCallHandler<ReqT, RespT> next) {
        /**
         * 处理请求的header,做一些特殊处理
         */
        //例如验证权限
        if (headers.get(USER_ID_KEY) == null) {
            //直接关闭请求
            serverCall.close(Status.UNAUTHENTICATED.withDescription("auth failed"), new Metadata());
        }
      //此处可以将header的值放入到ThreadLocal中 

        return new ForwardingServerCallListener.SimpleForwardingServerCallListener<ReqT>(next.startCall(
                /**
                 * 回调监听接口
                 */
                new ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT>(serverCall) {

                    /**
                     * 发送响应报文时候的拦截(可以打印响应报文)
                     * @param message response message.
                     */
                    @Override
                    public void sendMessage(RespT message) {
                        //可以做点什么(这里可以拿到接口的response)
                        log.info("test open sendMessage:{}", ObjectMapperUtils.toJSON(message));
                        super.sendMessage(message);
                        log.info("test close sendMessage");

                    }

                    /**
                     * 发送响应报文的拦截(可以设置响应header头)
                     * @param headers metadata to send prior to any response body.
                     */
                    @Override
                    public void sendHeaders(Metadata headers) {
                        log.info("test open sendHeaders");

                        Metadata.Key<Long> respXxx = Metadata.Key.of("respXxx", ASCII_LONG_MARSHALLER);
                        headers.put(respXxx, 1232L);
                        super.sendHeaders(headers);
                        log.info("test close sendHeaders");
                    }
                }, headers)) {

            /**
             * 打印请求报文
             */
            @Override
            public void onMessage(ReqT message) {
                //打印入参
                log.info("test open onMessage:{}", ObjectMapperUtils.toJSON(message));
                //参数处理,校验,若发现数据有误,则抛出异常
                log.info("test close onMessage");
            }


            /**
             * 代表本次请求正常结束
             */
            @Override
            public void onComplete() {
                log.info("test open onComplete()");
                //可以做点什么
                delegate().onComplete();
                log.info("test close onComplete()");

            }

            /**
             * 代表本次请求被取消掉,通常发生在服务端执行出现异常的情况会被调用。
             *
             * 例如请求超时,会执行到这个方法。
             */
            @Override
            public void onCancel() {
                log.info("test open onCancel()");
                delegate().onCancel();
                log.info("test close onCancel()");
            }

            /**
             * 贯穿整个请求的整个生命周期。
             */
            @Override
            public void onHalfClose() {
                log.info("test open onHalfClose()");
                log.info("test close onHalfClose()");
            }
        };
      //return next.startCall(serverCall, headers);

    }
}

执行的顺序:

open onMessage()
close onMessage()

open onHalfClose()
open 业务代码
open onComplete()
close onComplete()
close onHalfClose()

客户端

public class TestClientInterceptor implements ClientInterceptor {
    public static final Key<Long> USER_ID_KEY = Key.of("userId", ASCII_LONG_MARSHALLER);
    public static final Key<String> OPT_NAME_KEY = Key.of("userName", ASCII_STRING_MARSHALLER);

    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
                                                               CallOptions callOptions, Channel next) {
        
        return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {


            @Override
            public void start(Listener<RespT> responseListener, Metadata headers) {
                /**
                 * 发送请求报文的时候,对请求的heade进行处理。
                 */
                log.info("test open injectHeadersFromScope()");
                headers.put(USER_ID_KEY, 1001L);
                log.info("test close injectHeadersFromScope()");


                //开始请求,入参即为响应报文的处理,
                super.start(new ForwardingClientCallListener.SimpleForwardingClientCallListener<RespT>(responseListener) {

                    /**
                     * 响应报文来了,进行处理
                     * @param message returned by the server
                     */
                    public void onMessage(RespT message) {
                        log.info("test open onMessage:{}", ObjectMapperUtils.toJSON(message));
                        super.onMessage(message);
                        log.info("test close onMessage ");
                    }

                    /**
                     * 响应报文来了,获取到响应的header头信息
                     * @param headers containing metadata sent by the server at the start of the response.
                     */
                    @Override
                    public void onHeaders(Metadata headers) {
                        Key<Long> respXxx = Key.of("respXxx", ASCII_LONG_MARSHALLER);
                        log.info("test open onHeaders:{}", headers.get(respXxx));
                        super.onHeaders(headers);
                        log.info("test close onHeaders");
                    }

                    /**
                     *
                     * @param status the result of the remote call.   错误码
                     * @param trailers metadata provided at call completion.
                     */
                    @Override
                    public void onClose(Status status, Metadata trailers) {
                        log.info("test open onClose() :resp {}", ObjectMapperUtils.toJSON(status));

                        super.onClose(status, trailers);
                        log.info("test close onClose()");
                    }

                    @Override
                    public void onReady() {
                        log.info("test open onReady()");
                        super.onReady();
                        log.info("test close onReady()");
                    }
                }, headers);
                log.info("test close start()");

            }

            /**
             * 发送请求报文
             * @param message message to be sent to the server.
             */
            @Override
            public void sendMessage(ReqT message) {
                log.info("test open sendMessage():{}",ObjectMapperUtils.toJSON(message));
                super.sendMessage(message);
                log.info("test close sendMessage()");

            }

            @Override
            public void halfClose() {
                log.info("test open halfClose()");
                super.halfClose();
                log.info("test close halfClose()");
            }

            @Override
            public void cancel(String message, Throwable cause) {
                log.info("test open cancel()");
                super.cancel(message, cause);
                log.info("test close cancel()");
            }
        };
    }
}

如果仅仅是将ThreadLocal的值通过header向下传递,可以这样重写:

    @Override
    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
                                                               CallOptions callOptions, Channel next) {
        return new SimpleForwardingClientCall<ReqT, RespT>(next.newCall(method, callOptions)) {

            @Override
            public void start(Listener<RespT> responseListener, Metadata headers) {
                injectHeadersFromScope(headers);
                // @1 在Header中设置需要透传的值
                super.start(responseListener, headers);
            }
        };
    }
上一篇 下一篇

猜你喜欢

热点阅读