Spring 实现一个支持 分发 转发的 rest接口

2019-02-16  本文已影响0人  Yellowtail

需求概述

我们的需求是这样的:
后台的rest接口对接 安卓、IOS客户端、小程序
考虑到安全,我们对请求参数做了签名校验,对部分接口做了登录态校验
安全是安全了,但是 对于我们后台开发人员来说,进行接口测试就变得困难起来

接口测试,我们前面是这么做的

阶段一

在本地进行接口测试,把关键的Filter 注释掉
感受:太麻烦了,不小心就会把注释的代码提交了(我用的是乌龟git)

阶段二

我用Python写了一个本地代理,postman 配置好代理,让请求走我的代理
这个代理会模拟客户端,对参数进行签名
工程在这里:proxy
感受:我们有些接口使用到了Protobuf,因为传输的是二进制数据流,接口返回的数据,代理想转为可视化的json数据很费劲
(目前是先用protobuf生成Python文件,再反序列化,再显示,生成文件这一步需要经常维护)
而且我们的登录态用到了jwttoken动不动就过期了,需要更换;虽然使用postman的全局变量可以减少更换次数,但是如果想切换用户查看接口返回结果,又需要去数据库找到对应的token,再替换,有点费劲

经历过这两个阶段,后面就想着,能不能开发出这样一个接口呢:
参数里写上url method userId
然后这个接口根据这些参数,进行分发(转发),然后在Filter里面排除这个接口,不就避免了签名校验 登陆校验吗?

如何实现

刚开始实现的时候,也是一脸懵逼,用“转发” 关键字 搜了一番,发现都不能用在Rest 接口
最后没办法,想着去了解一下spring是怎么实现的,我再重复实现一遍不就OK了吗
于是去看了看 DispatcherServlet 源码解析文章,找到了思路

先看看 DispatcherServlet 的 核心方法 doDispatch

protected void doDispatch(HttpServletRequest request, HttpServletResponse response) throws Exception {
        HttpServletRequest processedRequest = request;
        HandlerExecutionChain mappedHandler = null;
        boolean multipartRequestParsed = false;

        WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);

        try {
            ModelAndView mv = null;
            Exception dispatchException = null;

            try {
                processedRequest = checkMultipart(request);
                multipartRequestParsed = (processedRequest != request);

                // Determine handler for the current request.
                                // 找到当前request对应的 handler
                mappedHandler = getHandler(processedRequest);
                if (mappedHandler == null || mappedHandler.getHandler() == null) {
                    noHandlerFound(processedRequest, response);
                    return;
                }

                // Determine handler adapter for the current request.
                                // 找到当前 handler 对应的 适配器
                HandlerAdapter ha = getHandlerAdapter(mappedHandler.getHandler());

                // Process last-modified header, if supported by the handler.
                String method = request.getMethod();
                boolean isGet = "GET".equals(method);
                if (isGet || "HEAD".equals(method)) {
                    long lastModified = ha.getLastModified(request, mappedHandler.getHandler());
                    if (logger.isDebugEnabled()) {
                        logger.debug("Last-Modified value for [" + getRequestUri(request) + "] is: " + lastModified);
                    }
                    if (new ServletWebRequest(request, response).checkNotModified(lastModified) && isGet) {
                        return;
                    }
                }

                                //执行 HandlerExecutionChain 里面 拦截器的 preHandle
                if (!mappedHandler.applyPreHandle(processedRequest, response)) {
                    return;
                }

                // Actually invoke the handler.
                                // 让适配器运行handler,也就是执行 Controller里的某个具体的方法
                mv = ha.handle(processedRequest, response, mappedHandler.getHandler());

                if (asyncManager.isConcurrentHandlingStarted()) {
                    return;
                }

                applyDefaultViewName(processedRequest, mv);

                                //执行 HandlerExecutionChain 里面 拦截器的 postHandle
                mappedHandler.applyPostHandle(processedRequest, response, mv);
            }
            catch (Exception ex) {
                dispatchException = ex;
            }
            catch (Throwable err) {
                // As of 4.3, we're processing Errors thrown from handler methods as well,
                // making them available for @ExceptionHandler methods and other scenarios.
                dispatchException = new NestedServletException("Handler dispatch failed", err);
            }
            processDispatchResult(processedRequest, response, mappedHandler, mv, dispatchException);
        }
        catch (Exception ex) {
            triggerAfterCompletion(processedRequest, response, mappedHandler, ex);
        }
        catch (Throwable err) {
            triggerAfterCompletion(processedRequest, response, mappedHandler,
                    new NestedServletException("Handler processing failed", err));
        }
        finally {
            if (asyncManager.isConcurrentHandlingStarted()) {
                // Instead of postHandle and afterCompletion
                if (mappedHandler != null) {
                    mappedHandler.applyAfterConcurrentHandlingStarted(processedRequest, response);
                }
            }
            else {
                // Clean up any resources used by a multipart request.
                if (multipartRequestParsed) {
                    cleanupMultipart(processedRequest);
                }
            }
        }
    }

关键的地方我都写了中文注释

所以 我们自己实现 分发 效果的时候,参考这个逻辑即可
找到handler --> 找到适配器 --> 执行拦截器 preHandle(如果需要) --> 执行 handler --> 执行拦截器 postHandle(如果需要)

实现代码

下面是 Controller

package com.xxx.app.skmr.controller;

import java.util.HashMap;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.apache.commons.lang3.StringUtils;
import org.powermock.reflect.Whitebox;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.servlet.DispatcherServlet;
import org.springframework.web.servlet.HandlerAdapter;
import org.springframework.web.servlet.HandlerExecutionChain;
import org.springframework.web.servlet.ModelAndView;

import com.xxx.app.skmr.constant.Constants;
import com.xxx.app.skmr.filter.EncryptRequest;
import com.xxx.app.skmr.properties.CommonConfigProperties;
import com.xxx.app.skmr.service.ReqCacheService;
import com.xxx.app.skmr.util.AssertHelper;

@RequestMapping("/v1")
@RestController
public class InterfaceTestController {
    
    private static final Logger LOGGER = LoggerFactory.getLogger(InterfaceTestController.class);
    
    @Autowired
    private DispatcherServlet dispatcherServlet;
    
    @Autowired
    private CommonConfigProperties commonConfigProperties;

    @RequestMapping(value = "/iTest", method = RequestMethod.POST)
    public void receiveRequest(
            @RequestParam(value="url", required=true) String url,
            @RequestParam(value="method", required=true) String method,
            @RequestParam(value="param", required=false) String param,
            @RequestParam(value="userId", required=false) String userId,
            HttpServletRequest request, 
            HttpServletResponse response) {
        
        LOGGER.info("InterfaceTestController receiveRequest, url={}", url);
        
        if (! Constants.SKMR_APP_ENV_NAME_DEV.equals(commonConfigProperties.getEnvName())) {
            //不是 dev 环境,直接退出
            return;
        }
        
        //设置登录态
        if (StringUtils.isNotBlank(userId)) {
            ReqCacheService.setReqUserId(request, userId);
        }
        
        //改变request里的值
        EncryptRequest myRequest = (EncryptRequest) request;
        
        //强行设置 header 里的 accept 为 application/json,为客户端省点事
        myRequest.setHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE);
        
        myRequest.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
        
        //修改方法,方法要大写
        myRequest.setMethod(method.toUpperCase());
        
        //修改 请求path
        myRequest.setRequestURI(url);
        
        //这个必须要设置,不然在 UrlPathHelper.getLookupPathForRequest()  方法里面 有问题,需要让rest 为空
        myRequest.setServletPath(url);
        
        //设置url 查询参数,
        if (StringUtils.isNotBlank(param)) {
            HashMap<String, String> convertParam = convertParam(param);
            
            myRequest.setUseExternalParam(true);
            myRequest.setParamMap(convertParam);
        }
        
        //转发
        redirect(myRequest, response);
        
        return ;
    }
    
    @RequestMapping(value = "/iTest", method = RequestMethod.GET)
    public void receiveRequestV2(
            @RequestParam(value="url", required=true) String url,
            @RequestParam(value="method", required=true) String method,
            @RequestParam(value="param", required=false) String param,
            @RequestParam(value="userId", required=false) String userId,
            HttpServletRequest request, 
            HttpServletResponse response) {
        
        //这个接口因为是 get 的,所以可以通过浏览器直接调用,不需要postman
        //当然了,对于有body的,浏览器就不行了,还是需要postman
        
        LOGGER.info("InterfaceTestController receiveRequestV2, url={}", url);
        
        receiveRequest(url, method, param, userId, request, response);
        
        return ;
    }
    
    /**
     * <br>转发 分发
     * <br>只有自身的url 会走 过滤器,转发后的 没有走过滤器
     *
     * @param request
     * @param response
     * @author YellowTail
     * @since 2019-02-15
     */
    private void redirect(HttpServletRequest request, HttpServletResponse response) {
        
        try {
            // 1. 得到 HandlerExecutionChain, 调用方法 getHandler 即可得到
            HandlerExecutionChain handlerExecutionChain = Whitebox.invokeMethod(dispatcherServlet, "getHandler", request);
            
            // 2. 取出 HandlerMethod,适配器要用
            HandlerMethod handlerMethod = (HandlerMethod) handlerExecutionChain.getHandler();
            
            // 3. 得到 适配器 HandlerAdapter,调用方法 getHandlerAdapter 得到
            HandlerAdapter ha = Whitebox.invokeMethod(dispatcherServlet, "getHandlerAdapter", handlerMethod);
            
            // 4. 执行 HandlerExecutionChain 拦截器的 preHandler() 前置方法, CmdHandlerInterceptor 会去设置 cmd
            Whitebox.invokeMethod(handlerExecutionChain, "applyPreHandle", request, response);
            
            // 5. 执行 handler
            ModelAndView mv = ha.handle(request, response, handlerMethod);
            
            // 6. 执行 拦截器的 postHandler() 方法, LogHandlerInterceptor 会去记录日志
            Whitebox.invokeMethod(handlerExecutionChain, "applyPostHandle", request, response, mv);
            
        } catch (Exception e) {
            LOGGER.error("error, ", e);
        }
    }
    
    /**
     * <br>将字符串形式的 参数  id=test&type=2 转换为 map,方便使用
     *
     * @param param
     * @return
     * @author YellowTail
     * @since 2019-02-15
     */
    private HashMap<String, String> convertParam(String param) {
        HashMap<String, String> map = new HashMap<>();
        
        if (StringUtils.isBlank(param)) {
            return map;
        }
        
        String[] split = param.split("&");
        
        for(String eachParam : split) {
            String[] split2 = eachParam.split("=");
            
            AssertHelper.assertTrue(2 >= split2.length , eachParam + " eachParam should contains one =");
            
            if (2 == split2.length) {
                map.put(split2[0], split2[1]);
            }
        }
        
        return map;
    }
}

因为我们需要对 request 做很多操作,所以必须自己实现一个request ,且继承 HttpServletRequestWrapper

package com.xxx.app.skmr.filter;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.StringReader;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Vector;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.xxx.app.skmr.util.IOSystem;
import com.xxx.app.skmr.util.StringUtils;


public class EncryptRequest extends HttpServletRequestWrapper {
    
    private static final Logger LOGGER = LoggerFactory.getLogger(EncryptRequest.class);
    
    /**
     * URL 的方法
     */
    private String _method;
    
    /**
     * URI
     */
    private String _requestURI;
    
    /**
     * ServletPath
     */
    private String _servletPath;
    
    
    /**
     * 是否使用 扩展的 参数
     */
    private boolean useExternalParam = false;
    
    /**
     * 参数 map
     */
    private HashMap<String, String> paramMap;
    
    /**
     * header
     */
    private HashMap<String, String> _headers;
    
    /**
     * 存储requestBody的内容
     */
    private byte[] requestBody = null;
    

   public EncryptRequest(HttpServletRequest request) {
       super(request);
  
       try {
           this.requestBody = IOSystem.readToBytes(request.getInputStream());
       } catch (IOException e) {
           e.printStackTrace();
           LOGGER.error("EncryptRequest init getInputStream error", e);
           throw new RuntimeException(e);
       }
   }

    /**
     * 获取requestbody
     */
    public byte[] getRequestBody() {
        return this.requestBody;
    }
    
    public String getRequestBodyString() {
        return StringUtils.toString(requestBody, this.getRequest().getCharacterEncoding());
    }
    
    public void setRequestBody(byte[] requestBody ) {
        this.requestBody = requestBody;
    }


    @Override
    public ServletInputStream getInputStream() throws IOException {
        // 如果是null 证明首次数据获取失败
        if (requestBody == null) {
            requestBody = new byte[0];
        }
        
        final ByteArrayInputStream bais = new ByteArrayInputStream(requestBody);
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return bais.read();
            }

            @Override
            public boolean isFinished() {
                return false;
            }

            @Override
            public boolean isReady() {
                return true;
            }

            @Override
            public void setReadListener(ReadListener listener) {
            }
        };
    }


    @Override
    public BufferedReader getReader() throws IOException {
        return new BufferedReader(new StringReader(this.getRequestBodyString()));
    }
    
    /**
     * 重写 getMethod,方便 变更 method
     */
    @Override
    public String getMethod() {
        if (null == _method) {
            _method = super.getMethod();
        }
        return _method;
    }
    
    /**
     * <br>扩展方法:支持修改  method
     * <br>在 接口测试的接口里使用到了
     *
     * @param newMethod
     * @author YellowTail
     * @since 2019-02-14
     */
    public void setMethod(String newMethod) {
        _method = newMethod;
    }
    
    /**
     * 覆盖,方便变更
     */
    @Override
    public String getRequestURI() {
        if (null == _requestURI) {
            _requestURI = super.getRequestURI();
        }
        return _requestURI;
    }
    
    /**
     * <br>扩展方法,支持变更 path
     * <br>在 接口测试的接口里使用到了
     *
     * @param value
     * @author YellowTail
     * @since 2019-02-14
     */
    public void setRequestURI(String value) {
        _requestURI = value;
    }
    
    @Override
    public String getServletPath() {
        if (null == _servletPath) {
            _servletPath = super.getServletPath();
        }
        return _servletPath;
    }
    
    public void setServletPath(String value) {
        _servletPath = value;
    }
    
    @Override
    public String[] getParameterValues(String name) {
        if (useExternalParam) {
            String string = paramMap.get(name);
            if (null == string) {
                return null;
            }
            
            return new  String [] {string};
        }
        return super.getParameterValues(name);
    }

    public void setUseExternalParam(boolean useExternalParam) {
        this.useExternalParam = useExternalParam;
    }

    public void setParamMap(HashMap<String, String> paramMap) {
        this.paramMap = paramMap;
    }
    
    
    /**
     * 覆盖 header 获取的方法,进行自定义扩展
     */
    @Override
    public Enumeration<String> getHeaders(String name) {
        
        //如果某个值被设置过,那么用自定义的值
        if (null != _headers && _headers.containsKey(name)) {
            String string = _headers.get(name);
            
            Vector<String> values = new Vector<String>();
            values.add(string);
            
            return values.elements();
        }
        
        return super.getHeaders(name);
    }
    
    /**
     * <br>设置 Header 里的值
     *
     * @param key
     * @param value
     * @author YellowTail
     * @since 2019-02-15
     */
    public void setHeader(String key, String value) {
        if (null == _headers) {
            _headers = new HashMap<>();
        }
        
        _headers.put(key, value);
    }
   
}

实现代码解释

因为 DispatcherServlet 的很多方法都是 protected, friendly 懒的搞继承,直接反射调用了
关于拦截器,因为我在实现的时候,是需要拦截器的一些效果(自动上传日志),所以就执行了步骤 4 和 6
大家在实现的时候,可以根据实际情况进行取舍

为何自定义request
因为 HttpServletRequest 只有一堆的 get 方法,没有 set 方法
看了下实现,反射好费劲,算了,直接继承一个,复写方法

效果

Type value
接口地址 /v1/iTest
接口方法 Post
接口Header header里的Accept Content-Type 都不需要设置,代码已经写死为application/json
设置为其他值不会生效
接口参数 是否必填 解释
url 必填 准备请求哪个接口
method 必填 接口的方法(因为有些接口url一样,但是method不一样),
大小写不敏感, get Get GET 都行
userId 非必填 设置登录态,即想用哪个用户请求接口
param 非必填 请求接口的请求参数,比如对于接口 /v1/xxx/me/list?unitId=2&nextid=&scope=2
那么param就是?后面的字符串,且需要进行url编码
也就是unitId%3D2%26nextid%3D%26scope%3D2

BUG 修复 2019年2月20日 10:31:35

修复了 param 里面 参数值为空抛异常的问题
代码已在此博客里更新

功能新增:把异常输出到浏览器上,省去看日志的步骤

一旦代码抛了异常,浏览器调用接口的时候,看不到信息
于是突发奇想,把 Exception 信息 参考 Logger 那种方式输出到屏幕上,多方便
于是写了一下,代码在下面,没有更新到 文章开始的那个代码块里

public static final String CHANGE_LINE = "\n".intern();
public static final String TAB_INDENT = "    at ";

...
catch (Exception e) {
            LOGGER.error("error, ", e);
            
            StringBuilder sb = new StringBuilder();
            
            sb.append(e.toString()).append(CHANGE_LINE);
            
            for(StackTraceElement st: e.getStackTrace()) {
                sb.append(TAB_INDENT)
                    .append(st.toString())
                    .append(CHANGE_LINE);
            }
            
            byte[] bytes = sb.toString().getBytes();
            HttpEncryptService.setResponseData(bytes, response, false);
        }
上一篇 下一篇

猜你喜欢

热点阅读