优雅地替换请求头

2021-12-24  本文已影响0人  guessguess

最近遇到一个需要替换请求头内容的功能。
考虑到业务代码中使用到请求头的业务代码量十分巨大。此外这样子调整的话代码的侵入性很高,所以,最终目的还是想着从请求的内容进行调整。

最后决定使用过滤器。那么为什么使用过滤器,因为希望等到请求真正被开始处理的时候,使用的都是被替换过的请求内容。

springmvc中对于请求的处理流程如下
request->filter->进入DispatcherServlet->通过urlhandlermapping找到处理器链(拦截器+本身的业务方法)
->转成适配器->通过适配器处理->拦截器pre->处理器---业务方法->生成视图->拦截器post->拦截器afterCompletion
->发送事件->将response返回

为什么选择过滤器

过滤器与拦截器本身有一个很大的区别,过滤器执行的过程中并没有进入DispatcherServlet.
如果一个请求进入DispatcherServlet被处理,那么我会认为这个请求才真正开始被处理。
所以选择过滤器,是考虑在真的被处理器之前,去完成对应的替换。

那么如何实现?

首先来看看相关的类结构


相关结构

其实比较简单。ServletRequest就是请求对应的接口。ServletRequestWrapper则是起了一个包装的作用。
至于有什么用?说白了就是提供拓展。
那么直接看看相关类的源码

ServletRequest

public interface ServletRequest {
    public Object getAttribute(String name);
    public Enumeration<String> getAttributeNames();
    public String getCharacterEncoding();
    public void setCharacterEncoding(String env) throws UnsupportedEncodingException;
    public int getContentLength();
    public long getContentLengthLong();
    public String getContentType();
    public ServletInputStream getInputStream() throws IOException; 
    public String getParameter(String name);
    public Enumeration<String> getParameterNames();
    public String[] getParameterValues(String name);
    public Map<String, String[]> getParameterMap();  
    public String getProtocol();
    public String getScheme();
    public String getServerName();
    public int getServerPort();
    public BufferedReader getReader() throws IOException;
    public String getRemoteAddr();
    public String getRemoteHost();
    public void setAttribute(String name, Object o);
    public void removeAttribute(String name);
    public Locale getLocale();
    public Enumeration<Locale> getLocales();
    public boolean isSecure();
    public RequestDispatcher getRequestDispatcher(String path);
    public String getRealPath(String path); 
    public int getRemotePort();
    public String getLocalName();
    public String getLocalAddr();
    public int getLocalPort();
    public ServletContext getServletContext();
    public AsyncContext startAsync() throws IllegalStateException;
    public AsyncContext startAsync(ServletRequest servletRequest,
                                   ServletResponse servletResponse)
            throws IllegalStateException;
    public boolean isAsyncStarted();
    public boolean isAsyncSupported();
    public AsyncContext getAsyncContext();
    public DispatcherType getDispatcherType();
}

方法很多,都是获取请求相关内容的接口。

ServletRequestWrapper

public class ServletRequestWrapper implements ServletRequest {

    private ServletRequest request;

    public ServletRequestWrapper(ServletRequest request) {
        if (request == null) {
            throw new IllegalArgumentException("Request cannot be null");   
        }
        this.request = request;
    }
    public String getContentType() {
        return this.request.getContentType();
    }
    。。。。若干方法
}

这个类的作用其实就是将ServletRequest封装成成员变量。
各个方法的实现都是利用ServletRequest本身的方法去完成的。

HttpServletRequestWrapper

HttpServletRequestWrapper与ServletRequestWrapper也是一样的,只不过包装的是HttpServletRequest。
源码如下

public class HttpServletRequestWrapper extends ServletRequestWrapper implements HttpServletRequest {
    public HttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    @Override
    public String getHeader(String name) {
        return this._getHttpServletRequest().getHeader(name);
    }
    private HttpServletRequest _getHttpServletRequest() {
        return (HttpServletRequest) super.getRequest();
    }
}

利用HttpServletRequestWrapper完成拓展。

源码如下

    private class ModifyParametersWrapper extends HttpServletRequestWrapper {
        private final Map<String, String> customHeaders;

        ModifyParametersWrapper(HttpServletRequest request) {
            super(request);
            this.customHeaders = new HashMap<>();
        }

        void putHeader(String name, String value) {
            this.customHeaders.put(name, value);
        }
        
        @Override
        public String getHeader(String name) {
            String headerValue = customHeaders.get(name);
            if (headerValue != null) {
                return headerValue;
            }
            return ((HttpServletRequest) getRequest()).getHeader(name);
        }
        
        @Override
        public Enumeration<String> getHeaders(String name) {
            String headerValue = customHeaders.get(name);
            Set<String> set = new HashSet<String>();
            if (headerValue != null) {
                set.add(headerValue);
                return Collections.enumeration(set);
            }
            return ((HttpServletRequest) getRequest()).getHeaders(name);
        }
        
        @Override
        public Enumeration<String> getHeaderNames() {
            Set<String> set = new HashSet<>(customHeaders.keySet());
            Enumeration<String> e = ((HttpServletRequest) getRequest()).getHeaderNames();
            while (e.hasMoreElements()) {
                String n = e.nextElement();
                set.add(n);
            }
            return Collections.enumeration(set);
        }
    }

这个类的实现也比较简单,其实说白了就是用一个新的成员变量来存储请求头。
优先获取新的请求头,其次获取旧的请求头。

结合过滤器使用

@WebFilter("/test/*")
public class ProjectConvertFilter implements Filter{
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
            throws IOException, ServletException {
        HttpServletRequest hsr = (HttpServletRequest)request;
        ModifyParametersWrapper mpw = new ModifyParametersWrapper(hsr);
        Integer oldprojectId = Integer.valueOf(hsr.getHeader("projectID"));
        Integer newProjectId = oldprojectId + 1;
        mpw.putHeader("projectID", newProjectId);
        chain.doFilter(mpw, response);
    }
}

如何无侵入使用

结合自动装配即可。

@Configuration
@ServletComponentScan(basePackages = "xxx")
public class WebFilterAutoConfiguration {

}
上一篇 下一篇

猜你喜欢

热点阅读