优雅地替换请求头
2021-12-24 本文已影响0人
guessguess
最近遇到一个需要替换请求头内容的功能。
考虑到业务代码中使用到请求头的业务代码量十分巨大。此外这样子调整的话代码的侵入性很高,所以,最终目的还是想着从请求的内容进行调整。
最后决定使用过滤器。那么为什么使用过滤器,因为希望等到请求真正被开始处理的时候,使用的都是被替换过的请求内容。
springmvc中对于请求的处理流程如下
request->filter->进入DispatcherServlet->通过urlhandlermapping找到处理器链(拦截器+本身的业务方法)
->转成适配器->通过适配器处理->拦截器pre->处理器---业务方法->生成视图->拦截器post->拦截器afterCompletion
->发送事件->将response返回
为什么选择过滤器
过滤器与拦截器本身有一个很大的区别,过滤器执行的过程中并没有进入DispatcherServlet.
如果一个请求进入DispatcherServlet被处理,那么我会认为这个请求才真正开始被处理。
所以选择过滤器,是考虑在真的被处理器之前,去完成对应的替换。
那么如何实现?
首先来看看相关的类结构
相关结构
其实比较简单。ServletRequest就是请求对应的接口。ServletRequestWrapper则是起了一个包装的作用。
至于有什么用?说白了就是提供拓展。
那么直接看看相关类的源码
ServletRequest
public interface ServletRequest {
public Object getAttribute(String name);
public Enumeration<String> getAttributeNames();
public String getCharacterEncoding();
public void setCharacterEncoding(String env) throws UnsupportedEncodingException;
public int getContentLength();
public long getContentLengthLong();
public String getContentType();
public ServletInputStream getInputStream() throws IOException;
public String getParameter(String name);
public Enumeration<String> getParameterNames();
public String[] getParameterValues(String name);
public Map<String, String[]> getParameterMap();
public String getProtocol();
public String getScheme();
public String getServerName();
public int getServerPort();
public BufferedReader getReader() throws IOException;
public String getRemoteAddr();
public String getRemoteHost();
public void setAttribute(String name, Object o);
public void removeAttribute(String name);
public Locale getLocale();
public Enumeration<Locale> getLocales();
public boolean isSecure();
public RequestDispatcher getRequestDispatcher(String path);
public String getRealPath(String path);
public int getRemotePort();
public String getLocalName();
public String getLocalAddr();
public int getLocalPort();
public ServletContext getServletContext();
public AsyncContext startAsync() throws IllegalStateException;
public AsyncContext startAsync(ServletRequest servletRequest,
ServletResponse servletResponse)
throws IllegalStateException;
public boolean isAsyncStarted();
public boolean isAsyncSupported();
public AsyncContext getAsyncContext();
public DispatcherType getDispatcherType();
}
方法很多,都是获取请求相关内容的接口。
ServletRequestWrapper
public class ServletRequestWrapper implements ServletRequest {
private ServletRequest request;
public ServletRequestWrapper(ServletRequest request) {
if (request == null) {
throw new IllegalArgumentException("Request cannot be null");
}
this.request = request;
}
public String getContentType() {
return this.request.getContentType();
}
。。。。若干方法
}
这个类的作用其实就是将ServletRequest封装成成员变量。
各个方法的实现都是利用ServletRequest本身的方法去完成的。
HttpServletRequestWrapper
HttpServletRequestWrapper与ServletRequestWrapper也是一样的,只不过包装的是HttpServletRequest。
源码如下
public class HttpServletRequestWrapper extends ServletRequestWrapper implements HttpServletRequest {
public HttpServletRequestWrapper(HttpServletRequest request) {
super(request);
}
@Override
public String getHeader(String name) {
return this._getHttpServletRequest().getHeader(name);
}
private HttpServletRequest _getHttpServletRequest() {
return (HttpServletRequest) super.getRequest();
}
}
利用HttpServletRequestWrapper完成拓展。
源码如下
private class ModifyParametersWrapper extends HttpServletRequestWrapper {
private final Map<String, String> customHeaders;
ModifyParametersWrapper(HttpServletRequest request) {
super(request);
this.customHeaders = new HashMap<>();
}
void putHeader(String name, String value) {
this.customHeaders.put(name, value);
}
@Override
public String getHeader(String name) {
String headerValue = customHeaders.get(name);
if (headerValue != null) {
return headerValue;
}
return ((HttpServletRequest) getRequest()).getHeader(name);
}
@Override
public Enumeration<String> getHeaders(String name) {
String headerValue = customHeaders.get(name);
Set<String> set = new HashSet<String>();
if (headerValue != null) {
set.add(headerValue);
return Collections.enumeration(set);
}
return ((HttpServletRequest) getRequest()).getHeaders(name);
}
@Override
public Enumeration<String> getHeaderNames() {
Set<String> set = new HashSet<>(customHeaders.keySet());
Enumeration<String> e = ((HttpServletRequest) getRequest()).getHeaderNames();
while (e.hasMoreElements()) {
String n = e.nextElement();
set.add(n);
}
return Collections.enumeration(set);
}
}
这个类的实现也比较简单,其实说白了就是用一个新的成员变量来存储请求头。
优先获取新的请求头,其次获取旧的请求头。
结合过滤器使用
@WebFilter("/test/*")
public class ProjectConvertFilter implements Filter{
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest hsr = (HttpServletRequest)request;
ModifyParametersWrapper mpw = new ModifyParametersWrapper(hsr);
Integer oldprojectId = Integer.valueOf(hsr.getHeader("projectID"));
Integer newProjectId = oldprojectId + 1;
mpw.putHeader("projectID", newProjectId);
chain.doFilter(mpw, response);
}
}
如何无侵入使用
结合自动装配即可。
@Configuration
@ServletComponentScan(basePackages = "xxx")
public class WebFilterAutoConfiguration {
}