How to Support Static Injection in Spring

2024-10-13  杭二

Original article (please credit the source when reposting): How to Support Static Injection in Spring? - Jianshu (jianshu.com)

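Real-world business scenario

Some utility classes, or classes used widely across business code, are consumed as beans. Using them that way feels somewhat verbose, although functionally there is nothing wrong with it.

Goal

This gives rise to a common wish: injecting beans into static fields in an @Autowired-like way, for example: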

public class BizUtils {
    @Autowired
    private static CommonService commonService;
}

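Obstacle

Spring injects dependencies into bean instances, so it cannot inject a class's static fields: AutowiredAnnotationBeanPostProcessor skips static fields (it only logs that the annotation is not supported on static fields), and the field above simply stays null.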
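To make the obstacle concrete, here is a toy model in plain JDK reflection (not Spring's actual code) of an instance-based field injector that, like Spring, skips static fields. `@Inject`, `ToyService`, `ToyUtils` and `ToyInjector` are hypothetical names used only for illustration.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Map;

// Hypothetical stand-in for @Autowired in this toy model
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@interface Inject {
}

class ToyService {
}

class ToyUtils {
    @Inject
    static ToyService staticService;   // belongs to the class: the toy injector skips it

    @Inject
    ToyService instanceService;        // belongs to the instance: gets injected
}

class ToyInjector {
    // Fills the non-static @Inject fields of one instance, looking beans up by field type
    static void injectInstance(Object target, Map<Class<?>, Object> beansByType) {
        for (Field field : target.getClass().getDeclaredFields()) {
            if (!field.isAnnotationPresent(Inject.class) || Modifier.isStatic(field.getModifiers())) {
                continue; // static state is not part of the instance being processed
            }
            try {
                field.setAccessible(true);
                field.set(target, beansByType.get(field.getType()));
            } catch (IllegalAccessException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    public static void main(String[] args) {
        ToyUtils utils = new ToyUtils();
        injectInstance(utils, Map.of(ToyService.class, new ToyService()));
        System.out.println(utils.instanceService != null);  // true
        System.out.println(ToyUtils.staticService == null); // true
    }
}
```

The static field stays `null` no matter how many `ToyUtils` instances are processed, which is the situation the rest of the article works around.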

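Approach

Set the static field through an instance

Since Spring works on instances, we can create an instance (a bean) by hand that receives the injection and then copies the injected bean into the static field, for example: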

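import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

// Must itself be a bean (here via @Component scanning), otherwise Spring never creates or injects it
@Component
public class BizUtilsInjector implements InitializingBean {
    @Autowired
    private CommonService commonService;

    @Override
    public void afterPropertiesSet() throws Exception {
        // Runs after commonService has been injected: copy it into the static field.
        // BizUtils.commonService must be visible here (e.g. package-private in the same package);
        // with the private field from the BizUtils snippet above, this assignment would not compile
        BizUtils.commonService = commonService;
    }
}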
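The lifecycle this relies on can be simulated without Spring. In the sketch below, `InitCallback` is a hypothetical stand-in for `InitializingBean`, and `bootstrap` plays the container's role: instantiate the injector, set its field, then call the init callback, which copies the instance field into the static one. All class names are illustrative.

```java
// Hypothetical stand-in for Spring's InitializingBean
interface InitCallback {
    void afterPropertiesSet();
}

class GreetingService {
    String hello() {
        return "hello";
    }
}

class GreetingUtils {
    // Package-private so the injector in the same package can assign it
    static GreetingService service;

    static String greet() {
        return service.hello();
    }
}

class GreetingUtilsInjector implements InitCallback {
    GreetingService service; // filled by the "container" before the callback runs

    @Override
    public void afterPropertiesSet() {
        GreetingUtils.service = service; // instance field -> static field
    }
}

class HolderPatternDemo {
    // Mimics the container: instantiate, inject the dependency, then invoke the init callback
    static void bootstrap(GreetingService service) {
        GreetingUtilsInjector injector = new GreetingUtilsInjector();
        injector.service = service;
        injector.afterPropertiesSet();
    }

    public static void main(String[] args) {
        bootstrap(new GreetingService());
        System.out.println(GreetingUtils.greet()); // hello
    }
}
```

Calling `GreetingUtils.greet()` before `bootstrap` would throw a `NullPointerException`, so the static field is only usable once the injector bean has been initialized.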

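Automating it with scanning

Setting the static field through an instance works in principle, but every such class needs a hand-maintained injector.
Combined with an annotation, the same steps can be automated: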

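  1. Define an annotation that marks fields for static injection, e.g. @StaticInject
  2. Scan the application's packages for all classes with fields annotated with @StaticInject
  3. Dynamically generate an XXXXInjector class for each of them
  4. Register each XXXXInjector as a bean, so that Spring injects the required beans into its properties
  5. Copy the beans that Spring injected into XXXXInjector into the static fields of the target class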

Implementation outline

1. Define an annotation that marks fields for static injection

/**
 * @author fuhangbo.hanger.uhfun
 **/
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface StaticInject {
    String value() default "";
}
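`RetentionPolicy.RUNTIME` is what makes the annotation visible to reflection-based scanning at runtime. The JDK-only sketch below (the `Demo*` and `StaticInjectReader` classes are hypothetical) shows how an empty `value()` can mean "resolve by type" while a non-empty one names the bean, and how non-static annotated fields are ignored.

```java
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.List;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
@interface StaticInject {
    String value() default "";
}

class DemoService {
}

class DemoUtils {
    @StaticInject
    static DemoService service;           // value() is "": resolve by type

    @StaticInject("secondaryService")
    static DemoService secondary;         // value() names the bean explicitly

    @StaticInject
    DemoService notStatic;                // annotated but not static: ignored
}

class StaticInjectReader {
    // Returns "fieldName -> resolution" for every static @StaticInject field of the class
    static List<String> describe(Class<?> type) {
        List<String> result = new ArrayList<>();
        for (Field field : type.getDeclaredFields()) {
            StaticInject annotation = field.getAnnotation(StaticInject.class);
            if (annotation == null || !Modifier.isStatic(field.getModifiers())) {
                continue;
            }
            String how = annotation.value().isEmpty() ? "byType" : "byName:" + annotation.value();
            result.add(field.getName() + " -> " + how);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(describe(DemoUtils.class));
    }
}
```

With `RetentionPolicy.CLASS` or `SOURCE` instead, `getAnnotation` would return `null` and nothing would be found.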

2. Scan the application's packages for all classes with fields annotated with @StaticInject

Get the application's base package
private String getApplicationBasePackage() {
    // Use the package of the @SpringBootApplication main class as the base package
    return applicationContext.getBeansWithAnnotation(SpringBootApplication.class)
        .values()
        .iterator()
        .next()
        .getClass().getPackage().getName();
}
Scan for @StaticInject fields with Reflections and build the injection info for each declaring class
public List<StaticFieldInjectInfo> scanAndInjectStaticFields() {
    // Get the base package
    String basePackage = getApplicationBasePackage();
    // Scan for fields annotated with @StaticInject using Reflections
    Reflections reflections = new Reflections(new ConfigurationBuilder()
        .forPackage(basePackage)
        .addScanners(Scanners.FieldsAnnotated)
    );
    return reflections.getFieldsAnnotatedWith(StaticInject.class)
        .stream().filter(f -> Modifier.isStatic(f.getModifiers()))
        .collect(groupingBy(Field::getDeclaringClass)).entrySet()
        .stream().map(
            e -> new StaticFieldInjectInfo(determineBeanNameFieldMap(e.getKey(), e.getValue()), e.getKey()))
        .collect(toList());
}
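Reflections is only used to find the annotated fields; the filtering and grouping afterwards are plain streams. The sketch below replaces the classpath scan with an explicit list of candidate classes (an assumption so that it runs with the JDK alone) and keeps the same "static only, then group by declaring class" pipeline. The annotation copy and the `OrderUtils`/`PriceUtils` classes are hypothetical.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@interface StaticInject {
    String value() default "";
}

class OrderUtils {
    @StaticInject static Object orderService;
    @StaticInject static Object userService;
    @StaticInject Object ignoredInstanceField; // not static: filtered out
}

class PriceUtils {
    @StaticInject static Object priceService;
}

class StaticFieldScan {
    // Stand-in for reflections.getFieldsAnnotatedWith(...): scan an explicit list of classes
    static Map<Class<?>, List<Field>> scan(List<Class<?>> candidates) {
        return candidates.stream()
            .flatMap(type -> Arrays.stream(type.getDeclaredFields()))
            .filter(f -> f.isAnnotationPresent(StaticInject.class))
            .filter(f -> Modifier.isStatic(f.getModifiers()))
            .collect(Collectors.groupingBy(Field::getDeclaringClass));
    }

    public static void main(String[] args) {
        scan(List.of(OrderUtils.class, PriceUtils.class)).forEach((type, fields) ->
            System.out.println(type.getSimpleName() + ": " + fields.size() + " field(s)"));
    }
}
```

Grouping by declaring class is what later yields one generated injector per target class.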

3. Dynamically generate the XXXXInjector class

Build the class dynamically (it is named <declaring class simple name> + StaticFieldBeanHolder)
// Use Byte Buddy to add the fields dynamically
DynamicType.Builder<?> builder = new ByteBuddy()
    .subclass(StaticFieldBeanHolder.class)
    .name(StaticFieldBeanHolder.class.getPackage().getName() + "."
        + staticFieldInjectInfo.getDeclaringSimpleName()
        + StaticFieldBeanHolder.class.getSimpleName());
// Define one property (field plus getter/setter) per bean, named after the bean
// and typed like the target static field
for (Entry<String, Field> entry : staticFieldInjectInfo.getInjectedFieldBeanMap().entrySet()) {
    String beanName = entry.getKey();
    Field field = entry.getValue();
    builder = builder.defineProperty(beanName, field.getType());
}
// Generate the class and load it into this class's class loader
Class<?> collectorClass = builder.make().load(StaticFieldBeanInjector.class.getClassLoader(),
    ClassLoadingStrategy.Default.INJECTION).getLoaded();
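`defineProperty(name, type)` gives the generated class a field together with a getter and setter, and the setter is what allows the bean definition in the next step to set the property by name. Below is a hand-written approximation (not actual Byte Buddy output) of the class generated for a hypothetical target with one bean named `commonService`, plus a small hypothetical `setProperty` helper that only illustrates the "property name to setter" convention; Spring's real property handling is more involved. One consequence, stated as an assumption: because the property is named after the bean, bean names that are not valid Java identifiers (for example containing `.` or `#`) would likely not work with this scheme.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

class CommonService {
}

// Hypothetical base class standing in for StaticFieldBeanHolder
class HolderBase {
}

// Hand-written approximation of the class generated for one bean named "commonService"
class BizUtilsHolder extends HolderBase {
    private CommonService commonService;

    public CommonService getCommonService() {
        return commonService;
    }

    public void setCommonService(CommonService commonService) {
        this.commonService = commonService;
    }
}

class PropertySetter {
    // Sets a JavaBeans-style property through its setter, e.g. "commonService" -> setCommonService
    static void setProperty(Object target, String property, Object value) {
        String setterName = "set" + Character.toUpperCase(property.charAt(0)) + property.substring(1);
        for (Method method : target.getClass().getMethods()) {
            if (method.getName().equals(setterName) && method.getParameterCount() == 1) {
                try {
                    method.setAccessible(true);
                    method.invoke(target, value);
                    return;
                } catch (IllegalAccessException | InvocationTargetException e) {
                    throw new IllegalStateException(e);
                }
            }
        }
        throw new IllegalArgumentException("No setter for property '" + property + "'");
    }

    public static void main(String[] args) {
        BizUtilsHolder holder = new BizUtilsHolder();
        CommonService bean = new CommonService();
        setProperty(holder, "commonService", bean);
        System.out.println(holder.getCommonService() == bean); // true
    }
}
```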

4. Register the generated XXXXInjector as a bean and wire in the beans it needs

BeanDefinitionBuilder definitionBuilder = BeanDefinitionBuilder.genericBeanDefinition(collectorClass);
definitionBuilder.addPropertyValue("staticFieldInjectInfo", staticFieldInjectInfo);
staticFieldInjectInfo.getInjectedFieldBeanMap().forEach((beanName, field) -> {
    // Inject the bean into the generated property of the same name
    definitionBuilder.addPropertyReference(beanName, beanName);
    // Make sure the referenced bean is created first
    definitionBuilder.addDependsOn(beanName);
});
String beanName = "staticFieldBeanCollectorFor" + staticFieldInjectInfo.getDeclaringSimpleName();
registry.registerBeanDefinition(beanName, definitionBuilder.getBeanDefinition());

5. Copy the beans that Spring injected into XXXXInjector into the static fields of the target class

@Data
@Slf4j
public static class StaticFieldBeanHolder {
    public StaticFieldInjectInfo staticFieldInjectInfo;

    @PostConstruct
    @SneakyThrows
    public void doInject() {
        // Copy each generated bean property into the matching static field of the target class
        doWithFields(getClass(), field -> {
            Field targetField = staticFieldInjectInfo.getInjectedField(field.getName());
            makeAccessible(field);
            makeAccessible(targetField);
            targetField.set(null, field.get(this));
            log.info("@StaticInject {}.{} injected bean {}",
                targetField.getDeclaringClass().getName(), targetField.getName(), field.getName());
        }, field -> staticFieldInjectInfo.containsField(field.getName()));
    }
}
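Without Spring and Lombok, `doInject` comes down to this: for every holder field whose name is in the injection map, read the instance value and write it to the mapped static field with `Field.set(null, value)`, where the `null` receiver is how reflection addresses a static field. A plain-JDK sketch with hypothetical `Invoice*` classes follows. Note that the target field must not be `final`: reflection refuses to write a static final field even after `setAccessible(true)`.

```java
import java.lang.reflect.Field;
import java.util.Map;

class InvoiceService {
}

class InvoiceUtils {
    private static InvoiceService invoiceService; // private: only reachable via reflection

    static InvoiceService current() {
        return invoiceService;
    }
}

// Stand-in for a generated holder: one instance field per bean, already filled by the container
class InvoiceUtilsHolder {
    InvoiceService invoiceService = new InvoiceService();
    String unrelated = "not in the injection map";
}

class StaticCopy {
    // Copies holder fields listed in 'targets' (holder field name -> static target field)
    static void copyToStatics(Object holder, Map<String, Field> targets) {
        for (Field field : holder.getClass().getDeclaredFields()) {
            Field target = targets.get(field.getName());
            if (target == null) {
                continue; // mirrors the FieldFilter: skip fields without a target
            }
            try {
                field.setAccessible(true);
                target.setAccessible(true);
                target.set(null, field.get(holder)); // null receiver: static field
            } catch (IllegalAccessException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    public static void main(String[] args) throws NoSuchFieldException {
        InvoiceUtilsHolder holder = new InvoiceUtilsHolder();
        copyToStatics(holder, Map.of("invoiceService", InvoiceUtils.class.getDeclaredField("invoiceService")));
        System.out.println(InvoiceUtils.current() == holder.invoiceService); // true
    }
}
```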

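Dependencies

  1. Reflections, for classpath scanning
<dependency>
    <groupId>org.reflections</groupId>
    <artifactId>reflections</artifactId>
    <!-- not managed by Spring Boot; Scanners and ConfigurationBuilder.forPackage need 0.10.x -->
    <version>0.10.2</version>
</dependency>
  2. Byte Buddy, for dynamic class generation
<dependency>
    <groupId>net.bytebuddy</groupId>
    <artifactId>byte-buddy</artifactId>
    <!-- version managed by Spring Boot's dependency management -->
</dependency>
  3. The code below also uses Lombok, Apache Commons Lang 3 (ArrayUtils, StringUtils) and JetBrains annotations (@NotNull).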

Implementation

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

// Spring Boot 2.x; on Spring Boot 3.x use jakarta.annotation.PostConstruct instead
import javax.annotation.PostConstruct;

import lombok.Data;
import lombok.Getter;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import net.bytebuddy.ByteBuddy;
import net.bytebuddy.dynamic.DynamicType;
import net.bytebuddy.dynamic.loading.ClassLoadingStrategy;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.reflections.Reflections;
import org.reflections.scanners.Scanners;
import org.reflections.util.ConfigurationBuilder;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Component;

import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static org.springframework.util.ReflectionUtils.doWithFields;
import static org.springframework.util.ReflectionUtils.makeAccessible;

/**
 * @author fuhangbo.hanger.uhfun
 **/
@Slf4j
@Component
public class StaticFieldBeanInjector implements
    BeanDefinitionRegistryPostProcessor, ApplicationContextAware {

    private static final String TRUE_INJECTOR_BEAN_NAME = StaticFieldBeanInjector.class.getSimpleName() + ".Last";

    private ApplicationContext applicationContext;

    // Keyed by bean name: if two static fields of the same class resolve to the same bean,
    // only the first field is kept
    private Map<String, Field> determineBeanNameFieldMap(Class<?> declaringClass, List<Field> injectedFields) {
        return injectedFields.stream().collect(toMap(field -> {
            StaticInject staticInject = AnnotationUtils.findAnnotation(field, StaticInject.class);
            String beanName = requireNonNull(staticInject).value();
            // Find candidate bean names by the field's type
            String[] candidateBeanNames = applicationContext.getBeanNamesForType(field.getType());
            if (StringUtils.isBlank(beanName)) {
                if (ArrayUtils.isEmpty(candidateBeanNames)) {
                    throw new NoSuchBeanDefinitionException(field.getType(),
                        String.format("@StaticInject %s.%s, bean type %s: no bean found", declaringClass.getSimpleName(),
                            field.getName(), field.getType().getSimpleName()));
                } else {
                    if (candidateBeanNames.length > 1) {
                        throw new NoUniqueBeanDefinitionException(field.getType(), candidateBeanNames.length,
                            String.format("@StaticInject %s.%s, bean type %s: multiple beans found: %s",
                                declaringClass.getSimpleName(),
                                field.getName(), field.getType().getSimpleName(),
                                String.join(",", candidateBeanNames)));
                    }
                    return candidateBeanNames[0];
                }
            } else if (!ArrayUtils.contains(candidateBeanNames, beanName)) {
                throw new NoSuchBeanDefinitionException(beanName,
                    String.format("@StaticInject %s.%s, bean type %s: specified bean not found: %s",
                        declaringClass.getSimpleName(),
                        field.getName(), field.getType().getSimpleName(), beanName));
            }
            return beanName;
        }, identity(), (first, duplicate) -> first, LinkedHashMap::new));
    }

    public List<StaticFieldInjectInfo> scanAndInjectStaticFields() {
        // Get the base package
        String basePackage = getApplicationBasePackage();
        // Scan for fields annotated with @StaticInject using Reflections
        Reflections reflections = new Reflections(new ConfigurationBuilder()
            .forPackage(basePackage)
            .addScanners(Scanners.FieldsAnnotated)
        );
        return reflections.getFieldsAnnotatedWith(StaticInject.class)
            .stream().filter(f -> Modifier.isStatic(f.getModifiers()))
            .collect(groupingBy(Field::getDeclaringClass)).entrySet()
            .stream().map(
                e -> new StaticFieldInjectInfo(determineBeanNameFieldMap(e.getKey(), e.getValue()), e.getKey()))
            .collect(toList());
    }

    private String getApplicationBasePackage() {
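        // Use the package of the @SpringBootApplication main class as the base package.
        // Note: here this runs during bean definition registry post-processing,
        // so getBeansWithAnnotation instantiates the main application bean early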
        return applicationContext.getBeansWithAnnotation(SpringBootApplication.class)
            .values()
            .iterator()
            .next()
            .getClass().getPackage().getName();
    }

    @Override
    public void postProcessBeanDefinitionRegistry(@NotNull BeanDefinitionRegistry registry) throws BeansException {
        if (registry.containsBeanDefinition(TRUE_INJECTOR_BEAN_NAME)) {
            List<StaticFieldInjectInfo> injectInfos = scanAndInjectStaticFields();
            for (StaticFieldInjectInfo staticFieldInjectInfo : injectInfos) {
                // Use Byte Buddy to add the fields dynamically.
                // Note: two target classes with the same simple name would produce the same class name
                DynamicType.Builder<?> builder = new ByteBuddy()
                    .subclass(StaticFieldBeanHolder.class)
                    .name(StaticFieldBeanHolder.class.getPackage().getName() + "."
                        + staticFieldInjectInfo.getDeclaringSimpleName()
                        + StaticFieldBeanHolder.class.getSimpleName());
                // Define one property (field plus getter/setter) per bean, named after the bean
                // and typed like the target static field
                for (Entry<String, Field> entry : staticFieldInjectInfo.getInjectedFieldBeanMap().entrySet()) {
                    String beanName = entry.getKey();
                    Field field = entry.getValue();
                    builder = builder.defineProperty(beanName, field.getType());
                }
                // Generate the class and load it into this class's class loader
                Class<?> collectorClass = builder.make().load(StaticFieldBeanInjector.class.getClassLoader(),
                    ClassLoadingStrategy.Default.INJECTION).getLoaded();
                BeanDefinitionBuilder definitionBuilder = BeanDefinitionBuilder
                    .genericBeanDefinition(collectorClass);

                definitionBuilder.addPropertyValue("staticFieldInjectInfo", staticFieldInjectInfo);
                staticFieldInjectInfo.getInjectedFieldBeanMap().forEach((beanName, field) -> {
                    // Inject the bean into the property of the same name, and create it first
                    definitionBuilder.addPropertyReference(beanName, beanName);
                    definitionBuilder.addDependsOn(beanName);
                });

                String beanName = "staticFieldBeanCollectorFor" + staticFieldInjectInfo.getDeclaringSimpleName();
                registry.registerBeanDefinition(beanName, definitionBuilder.getBeanDefinition());
            }
        }

        if (!registry.containsBeanDefinition(TRUE_INJECTOR_BEAN_NAME)) {
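            // Some beans are registered by other BeanDefinitionRegistryPostProcessors,
            // e.g. MyBatis's MapperScannerConfigurer, which registers the mapper proxy beans.
            // Those processors are sorted by order, but all unordered ones share the default
            // Ordered.LOWEST_PRECEDENCE, so ordering alone cannot make this processor run after them.
            // Workaround: register a second StaticFieldBeanInjector here. Spring keeps checking for newly
            // registered BeanDefinitionRegistryPostProcessors and invokes them in a later round, so the
            // second instance runs after all the others and performs the actual scan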
            BeanDefinitionBuilder injectorRootDefinitionBuilder = BeanDefinitionBuilder
                .genericBeanDefinition(StaticFieldBeanInjector.class);
            registry.registerBeanDefinition(TRUE_INJECTOR_BEAN_NAME,
                injectorRootDefinitionBuilder.getBeanDefinition());
        }
    }

    @Override
    public void setApplicationContext(@NotNull ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

    @Override
    public void postProcessBeanFactory(@NotNull ConfigurableListableBeanFactory beanFactory) throws BeansException {
    }

    @Data
    @Slf4j
    public static class StaticFieldBeanHolder {
        public StaticFieldInjectInfo staticFieldInjectInfo;

        @PostConstruct
        @SneakyThrows
        public void doInject() {
            // Copy each generated bean property into the matching static field of the target class
            doWithFields(getClass(), field -> {
                Field targetField = staticFieldInjectInfo.getInjectedField(field.getName());
                makeAccessible(field);
                makeAccessible(targetField);
                targetField.set(null, field.get(this));
                log.info("@StaticInject {}.{} injected bean {}",
                    targetField.getDeclaringClass().getName(), targetField.getName(), field.getName());
            }, field -> staticFieldInjectInfo.containsField(field.getName()));
        }
    }

    @Getter
    public static class StaticFieldInjectInfo {
        private final Map<String, Field> injectedFieldBeanMap;
        private final Class<?> declaringClass;

        public StaticFieldInjectInfo(Map<String, Field> injectedFieldBeanMap, Class<?> declaringClass) {
            this.declaringClass = declaringClass;
            this.injectedFieldBeanMap = injectedFieldBeanMap;
        }

        public String getDeclaringSimpleName() {
            return getDeclaringClass().getSimpleName();
        }

        public boolean containsField(String name) {
            return injectedFieldBeanMap.containsKey(name);
        }

        public Field getInjectedField(String name) {
            return injectedFieldBeanMap.get(name);
        }
    }
}
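The resolution rules in `determineBeanNameFieldMap` can be checked in isolation. The sketch below models the container as a hypothetical `Map<String, Class<?>>` of bean name to bean type (standing in for `getBeanNamesForType`) and applies the same rules: an empty `value()` requires exactly one bean of a compatible type, and a non-empty `value()` must name one of those candidates. Plain `IllegalStateException` replaces Spring's exception types, and all type names are illustrative.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

interface UserService {}
interface Cache {}
class UserServiceImpl implements UserService {}
class LocalCache implements Cache {}
class RemoteCache implements Cache {}

class BeanNameResolver {
    // Hypothetical registry: bean name -> bean type
    private final Map<String, Class<?>> beans;

    BeanNameResolver(Map<String, Class<?>> beans) {
        this.beans = beans;
    }

    // Names of beans assignable to the requested type (the role of getBeanNamesForType)
    List<String> candidates(Class<?> type) {
        return beans.entrySet().stream()
            .filter(e -> type.isAssignableFrom(e.getValue()))
            .map(Map.Entry::getKey)
            .collect(Collectors.toList());
    }

    // By type when no name is given, otherwise the name must be one of the type's candidates
    String resolve(Class<?> fieldType, String requestedName) {
        List<String> candidates = candidates(fieldType);
        if (requestedName == null || requestedName.trim().isEmpty()) {
            if (candidates.isEmpty()) {
                throw new IllegalStateException("No bean of type " + fieldType.getSimpleName());
            }
            if (candidates.size() > 1) {
                throw new IllegalStateException("Multiple beans of type " + fieldType.getSimpleName() + ": " + candidates);
            }
            return candidates.get(0);
        }
        if (!candidates.contains(requestedName)) {
            throw new IllegalStateException("No bean named " + requestedName + " of type " + fieldType.getSimpleName());
        }
        return requestedName;
    }

    public static void main(String[] args) {
        Map<String, Class<?>> beans = new LinkedHashMap<>();
        beans.put("userService", UserServiceImpl.class);
        beans.put("localCache", LocalCache.class);
        beans.put("remoteCache", RemoteCache.class);
        BeanNameResolver resolver = new BeanNameResolver(beans);
        System.out.println(resolver.resolve(UserService.class, ""));      // userService
        System.out.println(resolver.resolve(Cache.class, "remoteCache")); // remoteCache
    }
}
```

`resolve(Cache.class, "")` fails because two beans qualify, which is the case where the annotation's `value()` has to name the bean explicitly.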

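Approach 2 (simplified)

Approach 1 is rather convoluted; scanning and injecting directly gets there faster:

  1. Scan for static fields annotated with @StaticInject
  2. Once all Spring singletons have been instantiated, inject the beans into those fields via reflection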
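Approach 2 can be modeled without Spring. In the sketch below, a hypothetical `Map<String, Object>` plays the fully initialized container, an explicit class list replaces the Reflections scan, and each static `@StaticInject` field is written once from the bean named in `value()`. By-type resolution is left out to keep it short, and all class names are illustrative.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.Map;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
@interface StaticInject {
    String value() default "";
}

class MailService {
}

class NoticeUtils {
    @StaticInject("mailService")
    private static MailService mailService;

    static MailService mail() {
        return mailService;
    }
}

class DirectStaticInjector {
    // Runs once all beans exist: look each bean up by name and write it into the static field
    static int injectAll(List<Class<?>> classes, Map<String, Object> beansByName) {
        int injected = 0;
        for (Class<?> type : classes) {
            for (Field field : type.getDeclaredFields()) {
                StaticInject annotation = field.getAnnotation(StaticInject.class);
                if (annotation == null || !Modifier.isStatic(field.getModifiers())) {
                    continue;
                }
                Object bean = beansByName.get(annotation.value());
                if (bean == null) {
                    throw new IllegalStateException("No bean named '" + annotation.value() + "' for " + field);
                }
                try {
                    field.setAccessible(true);
                    field.set(null, bean); // null receiver: static field
                    injected++;
                } catch (IllegalAccessException e) {
                    throw new IllegalStateException(e);
                }
            }
        }
        return injected;
    }

    public static void main(String[] args) {
        MailService mail = new MailService();
        int count = injectAll(List.of(NoticeUtils.class), Map.of("mailService", mail));
        System.out.println(count + " " + (NoticeUtils.mail() == mail)); // 1 true
    }
}
```

Compared with approach 1, no classes are generated and no extra bean definitions are registered; the trade-off is that the static fields are only filled after all singletons have been created.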

Implementation 2

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
import org.reflections.Reflections;
import org.reflections.scanners.Scanners;
import org.reflections.util.ConfigurationBuilder;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.stereotype.Component;

import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.groupingBy;
import static java.util.stream.Collectors.mapping;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import static org.springframework.util.ReflectionUtils.makeAccessible;

/**
 * @author fuhangbo.hanger.uhfun
 **/
@Slf4j
@Component
public class StaticFieldBeanInjector implements ApplicationContextAware, SmartInitializingSingleton {

    private ApplicationContext applicationContext;

    // Maps each static field to the name of the bean it should receive
    private Map<Field, String> determineBeanNameFieldMap(Class<?> declaringClass, List<Field> injectedFields) {
        return injectedFields.stream().collect(toMap(identity(), field -> {
            StaticInject staticInject = AnnotationUtils.findAnnotation(field, StaticInject.class);
            String beanName = requireNonNull(staticInject).value();
            // Find candidate bean names by the field's type
            String[] candidateBeanNames = applicationContext.getBeanNamesForType(field.getType());
            if (StringUtils.isBlank(beanName)) {
                if (ArrayUtils.isEmpty(candidateBeanNames)) {
                    throw new NoSuchBeanDefinitionException(field.getType(),
                        String.format("@StaticInject %s.%s, bean type %s: no bean found", declaringClass.getSimpleName(),
                            field.getName(), field.getType().getSimpleName()));
                } else {
                    if (candidateBeanNames.length > 1) {
                        throw new NoUniqueBeanDefinitionException(field.getType(), candidateBeanNames.length,
                            String.format("@StaticInject %s.%s, bean type %s: multiple beans found: %s",
                                declaringClass.getSimpleName(),
                                field.getName(), field.getType().getSimpleName(),
                                String.join(",", candidateBeanNames)));
                    }
                    return candidateBeanNames[0];
                }
            } else if (!ArrayUtils.contains(candidateBeanNames, beanName)) {
                throw new NoSuchBeanDefinitionException(beanName,
                    String.format("@StaticInject %s.%s, bean type %s: specified bean not found: %s",
                        declaringClass.getSimpleName(),
                        field.getName(), field.getType().getSimpleName(), beanName));
            }
            return beanName;
        }, (first, duplicate) -> first, LinkedHashMap::new));
    }

    // Returns bean name -> all static fields that should receive that bean
    public Map<String, List<Field>> scanAndInjectStaticFields() {
        // Get the base package
        String basePackage = getApplicationBasePackage();
        // Scan for fields annotated with @StaticInject using Reflections
        Reflections reflections = new Reflections(new ConfigurationBuilder()
            .forPackage(basePackage)
            .addScanners(Scanners.FieldsAnnotated)
        );
        return reflections.getFieldsAnnotatedWith(StaticInject.class)
            .stream().filter(f -> Modifier.isStatic(f.getModifiers()))
            .collect(groupingBy(Field::getDeclaringClass)).entrySet()
            .stream()
            .map(e -> determineBeanNameFieldMap(e.getKey(), e.getValue()))
            .map(Map::entrySet)
            .flatMap(Collection::stream)
            // Group by bean name; since the map above is keyed by field, several fields can share one bean
            .collect(groupingBy(Entry::getValue, mapping(Entry::getKey, toList())));
    }

    private String getApplicationBasePackage() {
        // Use the package of the @SpringBootApplication main class as the base package
        return applicationContext.getBeansWithAnnotation(SpringBootApplication.class)
            .values()
            .iterator()
            .next()
            .getClass().getPackage().getName();
    }

    @Override
    public void setApplicationContext(@NotNull ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

    @SneakyThrows
    @Override
    public void afterSingletonsInstantiated() {
        Map<String, List<Field>> injectStaticFieldsMap = scanAndInjectStaticFields();
        for (Entry<String, List<Field>> entry : injectStaticFieldsMap.entrySet()) {
            String beanName = entry.getKey();
            List<Field> fields = entry.getValue();
            Object injectedBean = applicationContext.getBean(beanName);
            for (Field targetField : fields) {
                makeAccessible(targetField);
                targetField.set(null, injectedBean);
                log.info("@StaticInject {}.{} injected bean {}",
                    targetField.getDeclaringClass().getName(), targetField.getName(), beanName);
            }
        }

    }
}

