使用Filter以及RequestWrapper实现参数的动态添
熟悉java web开发的人都知道,当我们发起了一个http请求后,我们在服务端就很难动态的修改请求的数据了,因为HttpServletRequest是一个接口,而其实现类又因不同的web服务器不同,而实现不同,所以我们修改其实不同服务器的实现类也是一个不明智的想法。
但是最近我遇到一个业务场景,就是对每一个Http请求,都需要一个特定的参数,这个参数是用来追溯日志用的,于是团队刚开始的想法就是每次请求让前端传入一个uuid到后台,但是这样的做法让我感觉很费劲,于是我就想到了HttpServletWrapper。
HttpServletWrapper是HttpServletRequest的一个实现类(装饰器),由于多态的原因,我们可以使用该类,实现对父类方法的重写,使用该类可以在请求下发的时候,动态修改请求的内容。因此我便通过此类实现了对请求参数的扩展。
当然在这里,我们要熟悉如何通过HttpServletRequest获取我们的请求参数,如果我们获取get请求参数,我们一般通过request.getParameter(String name)的方法,如果获取post请求参数,我们就要区分post的请求方式了,如果使用form-data以及x-www-form-urlencoded,我们也可以通过上面的方法获取到,因为这些请求参数都会放入到request的一个请求参数Map中。但是如果我们把请求放入@Requestbody中呢,那请求参数就是会以json的形式存在stream流中了。而且ServletInputStream以及BufferedReader只能读取一次,读取一次后内容就会置空导致异常,因此我们要对这些流进行处理。相信看以下代码吧
(1)实现对请求的过滤
public class ParamServletFilter extends AbstractLogger implements Filter {
@Override
public void init(FilterConfig filterConfig) {
}
@Override
public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain)
throws IOException, ServletException {
HttpServletRequest request = (HttpServletRequest) req;
String uuid = GlobalUtils.generateUUID();
logger.info("请求路径: [{}], 请求的全局id: [{}]", request.getServletPath(), uuid);
chain.doFilter(new ParamServletRequestWrapper(request, uuid), resp);
}
@Override
public void destroy() {
}
}
(2) 实现对参数的扩展
public class ParamServletRequestWrapper extends HttpServletRequestWrapper {
private static final Logger LOGGER = LoggerFactory.getLogger(ParamServletRequestWrapper.class);
private Map<String , String[]> params = new HashMap<>();
private InputStream inputStream;
private ParamServletRequestWrapper(HttpServletRequest request) {
super(request);
}
ParamServletRequestWrapper(HttpServletRequest request, String correlationId) {
this(request);
if (!CollectionUtils.isEmpty(request.getParameterMap())) {
this.params.putAll(request.getParameterMap());
}
addParameter(correlationId);
dealRequestBody(request, correlationId);
}
/**
* 获取所有的参数名
*
* @return Map的key -> 参数名
*/
@Override
public Enumeration<String> getParameterNames() {
LOGGER.debug("getParameterNames: [{}]", JsonUtils.serialize(Collections.enumeration(params.keySet())));
return Collections.enumeration(params.keySet());
}
/**
* 根据参数名获取参数值
*
* @param name 参数名
* @return 参数传入值
*/
@Override
public String getParameter(String name) {
String[]values = params.get(name);
if (values == null || values.length == 0) {
return null;
}
LOGGER.debug("getParameter, name: [{}], value: [{}]", name, values[0]);
return values[0];
}
/**
* 根据参数名获取所有的参数值
*
* @param name 参数名
* @return 参数所有的值
*/
@Override
public String[] getParameterValues(String name) {
LOGGER.debug("getParameterValues, name: [{}], value: [{}]", name, params.get(name));
return params.get(name);
}
/**
* 获取所有的请求参数
*
* @return 请求参数
*/
@Override
public Map<String, String[]> getParameterMap() {
LOGGER.debug("getParameterMap, parameterMap: [{}]", JsonUtils.serialize(params));
return this.params;
}
/**
* 获取body的数据流
*
* @return {@link DefaultServletInputStream}
* @throws IOException io异常
*/
@Override
public ServletInputStream getInputStream() throws IOException {
return Objects.isNull(inputStream) ? null : new DefaultServletInputStream(inputStream);
}
/**
* 获取body内容的数据流
*
* @return {@link BufferedReader}
* @throws IOException io异常
*/
@Override
public BufferedReader getReader() throws IOException {
return Objects.isNull(inputStream) ? null : new BufferedReader(new InputStreamReader(inputStream));
}
/**
* 处理请求body的内容
*
* @param request 请求
* @param correlationId 全局id
*/
private void dealRequestBody(HttpServletRequest request, String correlationId) {
try {
ServletInputStream inputStream = request.getInputStream();
if (Objects.isNull(inputStream)) {
return;
}
String inputStr = IOUtils.toString(inputStream, Constants.UTF_8);
if (StringUtils.isEmpty(inputStr)) {
return;
}
LOGGER.debug("getInputStream, inputString: [{}]", inputStr);
JSONObject inputJson = JSONObject.parseObject(inputStr);
if (Objects.isNull(inputJson)) {
return;
}
//String encryptJson = inputJson.getString(Constants.ENCRYPT_CONTENT);
//String decryptJson = EncryptionUtils.decrypt(encryptJson);
//inputJson = JSONObject.parseObject(decryptJson);
if (!inputJson.containsKey(Constants.CORRELATION_ID)) {
inputJson.put(Constants.CORRELATION_ID, correlationId);
}
LOGGER.debug("getInputStream, inputJson: [{}]", inputJson.toJSONString());
this.inputStream = IOUtils.toInputStream(inputJson.toJSONString(), Constants.UTF_8);
} catch (Exception e) {
LOGGER.error("getInputStream failed, cause: [{}]", ExceptionUtils.getFullStackTrace(e));
}
}
/**
* 对map中加密字段进行解析
*
* 加密规则将所有的字段通过AES加密到encryptContent字段中
* 然后后端进行解密
*/
private void analysisMap() {
String[] contents = this.params.get(Constants.ENCRYPT_CONTENT);
if (GlobalUtils.isNull(contents)) {
return;
}
String content = contents[0];
String decrypt = EncryptionUtils.decrypt(content);
if (StringUtils.isEmpty(decrypt)) {
return;
}
String[] split = decrypt.split("&");
if (GlobalUtils.isNull(split)) {
return;
}
for (String query : split) {
String[] queryString = query.split("=");
if (!GlobalUtils.isNull(queryString) && queryString.length == 2) {
params.put(queryString[0], new String[] {queryString[1]});
}
}
}
/**
* 对map的值赋予新的值
*
* @param value 添加新的值
*/
private void addParameter(Object value) {
//analysisMap();
if (params.containsKey(Constants.CORRELATION_ID)) {
return;
}
if (value != null) {
if (value instanceof String[]) {
params.put(Constants.CORRELATION_ID, (String[]) value);
} else if (value instanceof String) {
params.put(Constants.CORRELATION_ID, new String[] {(String) value});
} else {
params.put(Constants.CORRELATION_ID, new String[] {String.valueOf(value)});
}
}
}
public class DefaultServletInputStream extends ServletInputStream {
private final InputStream sourceStream;
private boolean finished = false;
DefaultServletInputStream(InputStream sourceStream) {
ParamUtils.assertNotNull(sourceStream, "inputStream must not be null");
this.sourceStream = sourceStream;
}
public int read() throws IOException {
int data = this.sourceStream.read();
if (data == -1) {
this.finished = true;
}
return data;
}
public int available() throws IOException {
return this.sourceStream.available();
}
public void close() throws IOException {
super.close();
this.sourceStream.close();
}
public boolean isFinished() {
return this.finished;
}
public boolean isReady() {
return true;
}
public void setReadListener(ReadListener readListener) {
throw new UnsupportedOperationException();
}
}
}
注意 !!!
Enumeration<String> getParameterNames 此方法必须实现,不然添加的map的参数无法获取,以及本例中IOUtils使用了apache的common-io,以及json处理使用了fastJson,如果你的HttpMessageConverters不兼容,请替换成对应的Json处理器