spring security验证码过滤器

2019-08-19  本文已影响0人  yunqing_71

图形验证码过滤器:

/**
 * 因为spring security的验证流程上没有验证图形验证码的过滤器,所以我们应该自己写一个过滤器Filter,
 * 是一个在UsernamePasswordAuthenticationFilter之前的过滤器,这两个过滤器处理同一个请求,先执行
 * 我们的图形验证码过滤器,如果验证通过转给UsernamePassword-----Filter执行,如果验证不过,就抛出异常。
 *
 * OncePerRequestFilter 是spring的一个工具类,保证过滤器每次只被调用一次
 * InitializingBean 实现它的目的是,在其他参数组装完毕之后初始化要拦截请求的值
 */
@Setter
@Getter
public class ValidateCodeFilter extends OncePerRequestFilter implements InitializingBean {

    /**
     * 身份认证失败处理器
     */
    private AuthenticationFailureHandler authenticationFailureHandler;

    /**
     * spring处理session的工具类
     */
    private SessionStrategy sessionStrategy = new HttpSessionSessionStrategy();

    /**
     * spring工具类可用来做url请求匹配
     */
    private AntPathMatcher pathMatcher = new AntPathMatcher();

    /**
     * 所有图形验证码要拦截的请求
     */
    private Set<String> urls = new HashSet<>();

    /**
     * 配置文件yunqing.security.........
     */
    private SecurityProperties securityProperties;

    @Override
    public void afterPropertiesSet() throws ServletException {
        super.afterPropertiesSet();
        /**
         * 把application配置中的url用,隔开放进一个String数组里,当用户配置了需要进行图片检验码验证的url时候执行
         */
        while(!"".equals(securityProperties.getCode().getImage().getUrl())) {
            String[] configUrls = StringUtils.splitByWholeSeparatorPreserveAllTokens(securityProperties.getCode().getImage().getUrl(), ",");

            for (String configUrl : configUrls) {
                urls.add(configUrl);
            }
        }
        /**
         * 登录请求一定要拦截的,所以加进去
         */
        urls.add("/yunqing/login");
    }

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {

        /**
         * 如果配置的url有与当前请求的url匹配的,就进行图片验证
         */
        boolean action = false;
        for (String url : urls) {
            if(pathMatcher.match(url, request.getRequestURI())) {
                action = true;
            }
        }
        /**
         * action == true就执行图片验证码
         */
        if (action){

            try {

                validate(new ServletWebRequest(request));

            }catch (ValidateCodeException ex){
                /**
                 * 抛异常则执行登陆失败处理,并且返回,因为抛异常就不往下执行了
                 */
                authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
                return;
            }
        }
        /**
         * 如果不是登录请求,直接调用后面的过滤器
         */
        filterChain.doFilter(request,response);
    }

    /**
     * 验证图片验证码的逻辑
     * @param request
     */
    private void validate(ServletWebRequest request) throws ServletRequestBindingException {
        /**
         * 从session中取出之前生成验证码放进去的code
         */
        ImageCode codeInSession = (ImageCode) sessionStrategy.getAttribute(request, ValidateCodeProcessor.SESSION_KEY_PREFIX + "IMAGE");
        /**
         * 从请求中拿到用户填写的验证码imageCode
         */
        String codeInRequest = ServletRequestUtils.getStringParameter(request.getRequest(),"imageCode");

        if (StringUtils.isBlank(codeInRequest)) {
            throw new ValidateCodeException("验证码不能为空");
        }

        if (codeInSession == null) {
            throw new ValidateCodeException("验证码不存在");
        }
        /**
         * 判断验证码是否过期
         */
        if (codeInSession.isExpried()) {
            /**
             * 如果过期,把这个验证码从session删除
             */
            sessionStrategy.removeAttribute(request, ValidateCodeProcessor.SESSION_KEY_PREFIX + "IMAGE");
            throw new ValidateCodeException("验证码已过期");
        }

        if (!StringUtils.equals(codeInSession.getCode(), codeInRequest)) {
            throw new ValidateCodeException("验证码不匹配");
        }
        /**
         * 最后把验证码从session中删除
         */
        sessionStrategy.removeAttribute(request, ValidateCodeProcessor.SESSION_KEY_PREFIX + "IMAGE");
    }


}

短信验证码过滤器:

@Setter
@Getter
public class SmsCodeFilter extends OncePerRequestFilter implements InitializingBean {


    private AuthenticationFailureHandler authenticationFailureHandler;

    private SessionStrategy sessionStrategy = new HttpSessionSessionStrategy();

    /**
     * spring工具类
     */
    private AntPathMatcher pathMatcher = new AntPathMatcher();

    /**
     * 所有图形验证码要拦截的请求
     */
    private Set<String> urls = new HashSet<>();

    private SecurityProperties securityProperties;

    @Override
    public void afterPropertiesSet() throws ServletException {
        super.afterPropertiesSet();
        /**
         * 把application配置中的url用,隔开放进一个String数组里
         */
        String[] configUrls = StringUtils.splitByWholeSeparatorPreserveAllTokens(securityProperties.getCode().getSms().getUrl(), ",");

        for (String configUrl : configUrls) {
            urls.add(configUrl);
        }
        /**
         * 登录请求一定要拦截的,所以加进去
         */
        urls.add("/yunqing/mobile_login");
    }

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {

        /**
         * 如果配置的url有与当前请求的url匹配的,就进行图片验证
         */
        boolean action = false;
        for (String url : urls) {
            if(pathMatcher.match(url, request.getRequestURI())) {
                action = true;
            }
        }
        /**
         * action == true就执行图片验证码
         */
        if (action){

            try {

                validate(new ServletWebRequest(request));

            }catch (ValidateCodeException ex){
                /**
                 * 抛异常则执行登陆失败处理,并且返回,因为抛异常就不往下执行了
                 */
                authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
                return;
            }
        }
        /**
         * 如果不是登录请求,直接调用后面的过滤器
         */
        filterChain.doFilter(request,response);
    }

    /**
     * 验证图片验证码的逻辑
     * @param request
     */
    private void validate(ServletWebRequest request) throws ServletRequestBindingException {
        /**
         * 从session中取出之前生成验证码放进去的code
         */
        ValidateCode codeInSession = (ValidateCode) sessionStrategy.getAttribute(request, ValidateCodeProcessor.SESSION_KEY_PREFIX + "SMS");
        /**
         * 从请求中拿到用户填写的验证码imageCode
         */
        String codeInRequest = ServletRequestUtils.getStringParameter(request.getRequest(),"smsCode");

        if (StringUtils.isBlank(codeInRequest)) {
            throw new ValidateCodeException("验证码不能为空");
        }

        if (codeInSession == null) {
            throw new ValidateCodeException("验证码不存在");
        }
        /**
         * 判断验证码是否过期
         */
        if (codeInSession.isExpried()) {
            /**
             * 如果过期,把这个验证码从session删除
             */
            sessionStrategy.removeAttribute(request, ValidateCodeProcessor.SESSION_KEY_PREFIX + "SMS");
            throw new ValidateCodeException("验证码已过期");
        }

        if (!StringUtils.equals(codeInSession.getCode(), codeInRequest)) {
            throw new ValidateCodeException("验证码不匹配");
        }
        /**
         * 最后把验证码从session中删除
         */
        sessionStrategy.removeAttribute(request, ValidateCodeProcessor.SESSION_KEY_PREFIX + "SMS");
    }


}

重构合并这两个过滤器:

@Setter
@Getter
@Component("validateCodeFilter")
public class ValidateCodeFilter extends OncePerRequestFilter implements InitializingBean {

    /**
     * 身份认证失败处理器
     */
    @Autowired
    private AuthenticationFailureHandler authenticationFailureHandler;

    /**
     * 存放所有需要检验的验证码的url,application中配置的
     */
    private Map<String, ValidateCodeType> urlMap = new HashMap<>();

    /**
     * 系统中的校验码处理器
     */
    @Autowired
    private ValidateCodeProcessorHolder validateCodeProcessorHolder;

    /**
     * spring工具类可用来做url请求匹配
     */
    private AntPathMatcher pathMatcher = new AntPathMatcher();


    /**
     * 配置文件yunqing.security.........
     */
    @Autowired
    private SecurityProperties securityProperties;

    /**
     * 初始化要拦截的url配置信息
     * @throws ServletException
     */
    @Override
    public void afterPropertiesSet() throws ServletException {
        super.afterPropertiesSet();

        urlMap.put(SecurityConstants.DEFAULT_SIGN_IN_PROCESSING_URL_FORM, ValidateCodeType.IMAGE);
        addUrlToMap(securityProperties.getCode().getImage().getUrl(), ValidateCodeType.IMAGE);

        urlMap.put(SecurityConstants.DEFAULT_SIGN_IN_PROCESSING_URL_MOBILE, ValidateCodeType.SMS);
        addUrlToMap(securityProperties.getCode().getSms().getUrl(), ValidateCodeType.SMS);
    }

    /**
     * 讲系统中配置的需要校验验证码的URL根据校验的类型放入map
     *
     * @param urlString
     * @param type
     */
    protected void addUrlToMap(String urlString, ValidateCodeType type) {
        /**
         * 如果配置中的验证码不为空,按照逗号分隔加入到urlMap中
         */
        if (StringUtils.isNotBlank(urlString)) {
            String[] urls = StringUtils.splitByWholeSeparatorPreserveAllTokens(urlString, ",");
            for (String url : urls) {
                urlMap.put(url, type);
            }
        }
    }


    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
            throws ServletException, IOException {

        ValidateCodeType type = getValidateCodeType(request);
        if (type != null) {
            logger.info("校验请求(" + request.getRequestURI() + ")中的验证码,验证码类型" + type);
            try {
                validateCodeProcessorHolder.findValidateCodeProcessor(type)
                        .validate(new ServletWebRequest(request, response));
                logger.info("验证码校验通过");
            } catch (ValidateCodeException exception) {
                authenticationFailureHandler.onAuthenticationFailure(request, response, exception);
                return;
            }
        }

        chain.doFilter(request, response);

    }

    /**
     * 获取校验码的类型,如果当前请求不需要校验,则返回null
     *
     * @param request
     * @return
     */
    private ValidateCodeType getValidateCodeType(HttpServletRequest request) {
        ValidateCodeType result = null;
        if (!StringUtils.equalsIgnoreCase(request.getMethod(), "get")) {
            Set<String> urls = urlMap.keySet();
            for (String url : urls) {
                if (pathMatcher.match(url, request.getRequestURI())) {
                    result = urlMap.get(url);
                }
            }
        }
        return result;
    }


}

/**
 * 检验码处理器管理器
 */
@Component
public class ValidateCodeProcessorHolder {

    /**
     * 收集系统中所有 ValidateCodeProcessor 接口的实现类
     */
    @Autowired
    private Map<String, ValidateCodeProcessor> validateCodeProcessors;

    /**
     * @param type
     * @return
     */
    public ValidateCodeProcessor findValidateCodeProcessor(ValidateCodeType type) {
        return findValidateCodeProcessor(type.toString().toLowerCase());
    }

    /**
     * @param type
     * @return
     */
    public ValidateCodeProcessor findValidateCodeProcessor(String type) {
        String name = type.toLowerCase() + ValidateCodeProcessor.class.getSimpleName();
        ValidateCodeProcessor processor = validateCodeProcessors.get(name);
        if (processor == null) {
            throw new ValidateCodeException("验证码处理器" + name + "不存在");
        }
        return processor;
    }


}
/**
 * 抽象的图片验证码处理器
 *
 */
public abstract class AbstractValidateCodeProcessor<T extends ValidateCode> implements ValidateCodeProcessor {

    /**
     * 收集系统中所有的 {@link ValidateCodeGenerator} 接口的实现。
     */
    @Autowired
    private Map<String, ValidateCodeGenerator> validateCodeGenerators;

    @Autowired
    private ValidateCodeRepository validateCodeRepository;


    @Override
    public void create(ServletWebRequest request) throws Exception {
        T validateCode = generate(request);
        save(request, validateCode);
        send(request, validateCode);
    }

    /**
     * 生成校验码
     *
     * @param request
     * @return
     */
    @SuppressWarnings("unchecked")
    private T generate(ServletWebRequest request) {
        String type = getValidateCodeType(request).toString().toLowerCase();
        String generatorName = type + ValidateCodeGenerator.class.getSimpleName();
        ValidateCodeGenerator validateCodeGenerator = validateCodeGenerators.get(generatorName);
        if (validateCodeGenerator == null) {
            throw new ValidateCodeException("验证码生成器" + generatorName + "不存在");
        }
        return (T) validateCodeGenerator.generate(request);
    }

    /**
     * 保存校验码
     *
     * @param request
     * @param validateCode
     */
    private void save(ServletWebRequest request, T validateCode) {
        ValidateCode code = new ValidateCode(validateCode.getCode(), validateCode.getExpireTime());
        validateCodeRepository.save(request, code, getValidateCodeType(request));
    }

    /**
     * 发送校验码,由子类实现
     *
     * @param request
     * @param validateCode
     * @throws Exception
     */
    protected abstract void send(ServletWebRequest request, T validateCode) throws Exception;

    /**
     * 根据请求的url获取校验码的类型
     *
     * @param request
     * @return
     */
    private ValidateCodeType getValidateCodeType(ServletWebRequest request) {
        String type = StringUtils.substringBefore(getClass().getSimpleName(), "CodeProcessor");
        return ValidateCodeType.valueOf(type.toUpperCase());
    }

    @SuppressWarnings("unchecked")
    @Override
    public void validate(ServletWebRequest request) {

        ValidateCodeType codeType = getValidateCodeType(request);

        T codeInSession = (T) validateCodeRepository.get(request, codeType);

        String codeInRequest;
        try {
            codeInRequest = ServletRequestUtils.getStringParameter(request.getRequest(),
                    codeType.getParamNameOnValidate());
        } catch (ServletRequestBindingException e) {
            throw new ValidateCodeException("获取验证码的值失败");
        }

        if (StringUtils.isBlank(codeInRequest)) {
            throw new ValidateCodeException(codeType + "请填写验证码");
        }

        if (codeInSession == null) {
            throw new ValidateCodeException(codeType + "验证码不存在");
        }

        if (codeInSession.isExpried()) {
            validateCodeRepository.remove(request, codeType);
            throw new ValidateCodeException(codeType + "验证码已过期,请重新获取");
        }

        if (!StringUtils.equals(codeInSession.getCode(), codeInRequest)) {
            throw new ValidateCodeException(codeType + "验证码不正确");
        }

        validateCodeRepository.remove(request, codeType);

    }

}
上一篇 下一篇

猜你喜欢

热点阅读