ThreadLocal使用和原理

2023-01-17  本文已影响0人  定金喜

1.使用场景

场景1:全局获取request,response

// RequestContextHolder.getRequestAttributes() may return null (for example on a thread that is not
// handling a request), which would make the next line fail with a bare NullPointerException.
// currentRequestAttributes() never returns null: it throws an IllegalStateException explaining
// that no request is bound to the current thread.
ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.currentRequestAttributes();
HttpServletRequest request = requestAttributes.getRequest();
HttpServletResponse response = requestAttributes.getResponse();

可以查看RequestContextHolder源码:

package org.springframework.web.context.request;

import javax.faces.context.FacesContext;
import org.springframework.core.NamedInheritableThreadLocal;
import org.springframework.core.NamedThreadLocal;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;

/**
 * Holds the current web request's {@link RequestAttributes} in thread-bound storage, so that code
 * anywhere on the request-handling thread can reach the request without it being passed around.
 *
 * <p>Two holders are kept: a plain named ThreadLocal, and a NamedInheritableThreadLocal whose value
 * is also seen by threads created afterwards by the bound thread. The setters always clear the
 * holder they do not write, so at most one of the two holds a value.
 */
public abstract class RequestContextHolder {
    /** Whether JSF is on the classpath; if so, FacesContext is used as a last-resort source. */
    private static final boolean jsfPresent = ClassUtils.isPresent("javax.faces.context.FacesContext", RequestContextHolder.class.getClassLoader());

    private static final ThreadLocal<RequestAttributes> requestAttributesHolder = new NamedThreadLocal<>("Request attributes");

    private static final ThreadLocal<RequestAttributes> inheritableRequestAttributesHolder = new NamedInheritableThreadLocal<>("Request context");

    public RequestContextHolder() {
    }

    /** Unbinds any request attributes from the current thread (clears both holders). */
    public static void resetRequestAttributes() {
        requestAttributesHolder.remove();
        inheritableRequestAttributesHolder.remove();
    }

    /**
     * Binds the given attributes to the current thread without exposing them to child threads.
     *
     * @param attributes the attributes to bind, or {@code null} to reset the thread-bound context
     */
    public static void setRequestAttributes(@Nullable RequestAttributes attributes) {
        setRequestAttributes(attributes, false);
    }

    /**
     * Binds the given attributes to the current thread.
     *
     * @param attributes  the attributes to bind, or {@code null} to reset the thread-bound context
     * @param inheritable {@code true} to store them in the inheritable holder, so threads created
     *                    later by this thread see them too; {@code false} for this thread only
     */
    public static void setRequestAttributes(@Nullable RequestAttributes attributes, boolean inheritable) {
        if (attributes == null) {
            resetRequestAttributes();
        } else if (inheritable) {
            inheritableRequestAttributesHolder.set(attributes);
            requestAttributesHolder.remove();
        } else {
            requestAttributesHolder.set(attributes);
            inheritableRequestAttributesHolder.remove();
        }
    }

    /**
     * Returns the attributes bound to the current thread, checking the plain holder first and then
     * the inheritable one.
     *
     * @return the bound attributes, or {@code null} if none are bound
     */
    @Nullable
    public static RequestAttributes getRequestAttributes() {
        RequestAttributes attributes = requestAttributesHolder.get();
        if (attributes == null) {
            attributes = inheritableRequestAttributesHolder.get();
        }
        return attributes;
    }

    /**
     * Returns the attributes bound to the current thread. Falls back to JSF's FacesContext when
     * JSF is present.
     *
     * @return the bound attributes, never {@code null}
     * @throws IllegalStateException if no attributes are bound and no JSF fallback is available
     */
    public static RequestAttributes currentRequestAttributes() throws IllegalStateException {
        RequestAttributes attributes = getRequestAttributes();
        if (attributes == null) {
            if (jsfPresent) {
                attributes = FacesRequestAttributesFactory.getFacesRequestAttributes();
            }
            if (attributes == null) {
                throw new IllegalStateException("No thread-bound request found: Are you referring to request attributes outside of an actual web request, or processing a request outside of the originally receiving thread? If you are actually operating within a web request and still receive this message, your code is probably running outside of DispatcherServlet: In this case, use RequestContextListener or RequestContextFilter to expose the current request.");
            }
        }
        return attributes;
    }

    /**
     * Builds request attributes from the current FacesContext. Only referenced behind the
     * {@code jsfPresent} check, which keeps the JSF dependency optional.
     */
    private static class FacesRequestAttributesFactory {
        private FacesRequestAttributesFactory() {
        }

        /** Returns attributes wrapping the current FacesContext, or {@code null} if there is none. */
        @Nullable
        public static RequestAttributes getFacesRequestAttributes() {
            FacesContext facesContext = FacesContext.getCurrentInstance();
            return facesContext != null ? new FacesRequestAttributes(facesContext) : null;
        }
    }
}

场景2:日志上下文信息

<!-- ERROR-only file appender. The %X{...} conversion words print MDC values of the logging thread;
     the keys traceId, uId, domId and uri are the ones LoggerFilter puts into the MDC. -->
<appender name="ERROR-FILE"
          class="ch.qos.logback.core.rolling.RollingFileAppender">
    <encoder>
        <pattern><![CDATA[
[%d{yyyy-MM-dd HH:mm:ss}] [traceid=%X{traceId},uid=%X{uId},corpid=%X{domId},uri=%X{uri}]  %-5level %logger{35} - %m%n
        ]]></pattern>
        <charset>utf8</charset>
    </encoder>
    <file>${LOG_PATH}/error.log</file>
    <!-- Roll over daily, and also whenever the active file exceeds maxFileSize; %i numbers the files
         produced within one day. SizeAndTimeBasedRollingPolicy replaces the deprecated
         TimeBasedRollingPolicy + SizeAndTimeBasedFNATP combination with the same behaviour. -->
    <rollingPolicy class="ch.qos.logback.core.rolling.SizeAndTimeBasedRollingPolicy">
        <fileNamePattern>${LOG_PATH}/error.%d{yyyy-MM-dd}-%i.log</fileNamePattern>
        <maxFileSize>500MB</maxFileSize>
        <!-- keep archives for the last 5 rollover periods (days, given the %d{yyyy-MM-dd} pattern) -->
        <maxHistory>5</maxHistory>
    </rollingPolicy>
    <!-- accept ERROR events only; every other level is denied by this appender -->
    <filter class="ch.qos.logback.classic.filter.LevelFilter">
        <level>ERROR</level>
        <onMatch>ACCEPT</onMatch>
        <onMismatch>DENY</onMismatch>
    </filter>
</appender>

一般通过 org.slf4j.MDC 把上下文信息放入当前线程,输出日志时由 pattern 中的 %X{key} 取出并打印。

输出日志例子:

(原文此处为日志输出示例截图,图片缺失)

代码:

package com.alibaba.atlas.security;

import org.jboss.logging.MDC;
import org.springframework.util.AntPathMatcher;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;

/**
 * Servlet filter that puts per-request diagnostic values (referer, user agent, URI, a fresh trace
 * id, and user/domain ids) into the MDC for the duration of the request. Log patterns such as
 * {@code %X{traceId}} can then print them. The MDC is cleared when the request completes.
 *
 * <p>Requests whose URI matches one of the Ant-style exclude patterns given to the constructor are
 * passed down the chain without touching the MDC. Non-HTTP requests are passed through as well.
 *
 * <p>NOTE(review): this class uses {@code org.jboss.logging.MDC}. That presumably forwards to
 * whatever logging backend JBoss Logging selected at runtime. A logback {@code %X{...}} pattern only
 * sees these values if that backend is SLF4J/logback; confirm for the deployment.
 *
 * @author:dxc
 */
public class LoggerFilter implements Filter {

    private static final String REFERER_KEY = "referer";
    private static final String USER_AGENT_KEY = "ua";
    private static final String UID_KEY = "uId";
    private static final String DOM_KEY = "domId";
    private static final String URI_KEY = "uri";
    public static final String TRACE_ID_KEY = "traceId";

    private static final String USER_AGENT_NAME = "User-Agent";
    private static final String REFERER_NAME = "Referer";

    private static final AntPathMatcher ANT_PATH_MATCHER = new AntPathMatcher();

    /**
     * Ant-style URI patterns that bypass MDC setup. Kept per instance, so filters created with
     * different patterns do not overwrite each other's configuration.
     */
    private final List<String> excludeUrlList;

    /**
     * @param excludeUrl comma-separated Ant-style URI patterns (e.g. {@code "/health,/static/**"})
     *                   that should bypass MDC setup, or {@code null} to exclude nothing.
     *                   Whitespace around each pattern is ignored and empty entries are skipped.
     */
    public LoggerFilter(String excludeUrl) {
        List<String> patterns = new ArrayList<>();
        if (excludeUrl != null) {
            for (String rawPattern : excludeUrl.split(",")) {
                String pattern = rawPattern.trim();
                if (!pattern.isEmpty()) {
                    patterns.add(pattern);
                }
            }
        }
        this.excludeUrlList = patterns;
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // nothing to initialise: configuration comes from the constructor
    }

    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException {
        // Only HTTP requests carry a URI and headers to log; pass anything else through instead of
        // failing with a ClassCastException.
        if (!(servletRequest instanceof HttpServletRequest)) {
            chain.doFilter(servletRequest, servletResponse);
            return;
        }
        HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
        String requestURI = httpServletRequest.getRequestURI();

        if (isExclude(requestURI)) {
            chain.doFilter(servletRequest, servletResponse);
            return;
        }

        String userAgent = httpServletRequest.getHeader(USER_AGENT_NAME);
        String referer = httpServletRequest.getHeader(REFERER_NAME);
        try {
            MDC.put(REFERER_KEY, referer);
            MDC.put(DOM_KEY, "demo"); // hard-coded example value
            MDC.put(UID_KEY, "dxc");  // hard-coded example value
            MDC.put(USER_AGENT_KEY, userAgent);
            MDC.put(URI_KEY, requestURI);
            MDC.put(TRACE_ID_KEY, UUID.randomUUID().toString().replace("-", ""));

            chain.doFilter(servletRequest, servletResponse);
        } finally {
            // The MDC is thread-bound and servlet containers typically reuse worker threads, so
            // clear it to keep this request's values out of the next request's logs.
            // Note that clear() drops every MDC entry on this thread, not only the keys set above.
            MDC.clear();
        }
    }

    @Override
    public void destroy() {
        // nothing to release
    }

    /** Returns whether {@code uri} matches any of this filter's exclude patterns. */
    private boolean isExclude(String uri) {
        return matchUri(excludeUrlList, uri);
    }

    /** Returns whether {@code uri} matches any of the Ant-style {@code patterns}. */
    private static boolean matchUri(List<String> patterns, String uri) {
        for (String pattern : patterns) {
            if (ANT_PATH_MATCHER.match(pattern, uri)) {
                return true;
            }
        }
        return false;
    }
}

MDC源码:

package ch.qos.logback.classic.util;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.slf4j.spi.MDCAdapter;

/**
 * Logback implementation of SLF4J's {@link MDCAdapter}. Each thread's MDC is a map held in a
 * {@link ThreadLocal}, so values put on one thread are not visible to other threads.
 *
 * <p>Copy-on-write after reads: {@link #getPropertyMap()} returns the live map and records a read.
 * The next put/remove on the same thread then copies the map first, so a map that has been handed
 * out is not modified afterwards. Consecutive writes with no read in between update the current map
 * in place.
 */
public class LogbackMDCAdapter implements MDCAdapter {
    /**
     * The current thread's MDC map. It is {@code null} until the first put or setContextMap, and
     * again after clear().
     */
    final ThreadLocal<Map<String, String>> copyOnThreadLocal = new ThreadLocal<>();
    private static final int WRITE_OPERATION = 1;
    private static final int MAP_COPY_OPERATION = 2;
    /** This thread's last operation (WRITE_OPERATION or MAP_COPY_OPERATION), {@code null} if none. */
    final ThreadLocal<Integer> lastOperation = new ThreadLocal<>();

    public LogbackMDCAdapter() {
    }

    /** Records {@code op} as this thread's last operation and returns the previous one. */
    private Integer getAndSetLastOperation(int op) {
        Integer lastOp = lastOperation.get();
        lastOperation.set(op);
        return lastOp;
    }

    /** Whether a write must copy first: nothing recorded yet, or the map was last handed out by a read. */
    private boolean wasLastOpReadOrNull(Integer lastOp) {
        return lastOp == null || lastOp == MAP_COPY_OPERATION;
    }

    /** Installs a fresh synchronized copy of {@code oldMap} (empty if null) for this thread and returns it. */
    private Map<String, String> duplicateAndInsertNewMap(Map<String, String> oldMap) {
        Map<String, String> newMap = Collections.synchronizedMap(new HashMap<>());
        if (oldMap != null) {
            // Every stored map is a synchronizedMap; copying iterates it, which requires its lock.
            synchronized (oldMap) {
                newMap.putAll(oldMap);
            }
        }
        copyOnThreadLocal.set(newMap);
        return newMap;
    }

    /**
     * Puts {@code key -> val} into the current thread's MDC ({@code val} may be null).
     *
     * @throws IllegalArgumentException if {@code key} is null
     */
    public void put(String key, String val) throws IllegalArgumentException {
        if (key == null) {
            throw new IllegalArgumentException("key cannot be null");
        }
        Map<String, String> oldMap = copyOnThreadLocal.get();
        Integer lastOp = getAndSetLastOperation(WRITE_OPERATION);
        if (!wasLastOpReadOrNull(lastOp) && oldMap != null) {
            oldMap.put(key, val);
        } else {
            Map<String, String> newMap = duplicateAndInsertNewMap(oldMap);
            newMap.put(key, val);
        }
    }

    /** Removes {@code key} from the current thread's MDC; a null key or an absent map is a no-op. */
    public void remove(String key) {
        if (key == null) {
            return;
        }
        Map<String, String> oldMap = copyOnThreadLocal.get();
        if (oldMap == null) {
            return;
        }
        Integer lastOp = getAndSetLastOperation(WRITE_OPERATION);
        if (wasLastOpReadOrNull(lastOp)) {
            Map<String, String> newMap = duplicateAndInsertNewMap(oldMap);
            newMap.remove(key);
        } else {
            oldMap.remove(key);
        }
    }

    /** Drops the current thread's MDC map entirely. */
    public void clear() {
        lastOperation.set(WRITE_OPERATION);
        copyOnThreadLocal.remove();
    }

    /** Returns the value for {@code key} on this thread, or {@code null} if there is none. */
    public String get(String key) {
        Map<String, String> map = copyOnThreadLocal.get();
        return map != null && key != null ? map.get(key) : null;
    }

    /** Returns the live map (not a copy) and marks it as read, so the next write copies it first. */
    public Map<String, String> getPropertyMap() {
        lastOperation.set(MAP_COPY_OPERATION);
        return copyOnThreadLocal.get();
    }

    /** Returns the keys of the live map (marking it as read), or {@code null} if there is no map. */
    public Set<String> getKeys() {
        Map<String, String> map = getPropertyMap();
        return map != null ? map.keySet() : null;
    }

    /** Returns a detached copy of this thread's MDC, or {@code null} if there is no map. */
    public Map<String, String> getCopyOfContextMap() {
        Map<String, String> map = copyOnThreadLocal.get();
        if (map == null) {
            return null;
        }
        return new HashMap<>(map);
    }

    /**
     * Replaces this thread's MDC with a synchronized copy of {@code contextMap}.
     * A null {@code contextMap} throws NullPointerException.
     */
    public void setContextMap(Map<String, String> contextMap) {
        lastOperation.set(WRITE_OPERATION);
        Map<String, String> newMap = Collections.synchronizedMap(new HashMap<>());
        newMap.putAll(contextMap);
        copyOnThreadLocal.set(newMap);
    }
}

场景3:数据库事务

AbstractPlatformTransactionManager 借助 TransactionSynchronizationManager 把事务资源绑定到当前线程,后者的源码片段如下:

public abstract class TransactionSynchronizationManager {

    private static final Log logger = LogFactory.getLog(TransactionSynchronizationManager.class);

    private static final ThreadLocal<Map<Object, Object>> resources =
            new NamedThreadLocal<>("Transactional resources");

    private static final ThreadLocal<Set<TransactionSynchronization>> synchronizations =
            new NamedThreadLocal<>("Transaction synchronizations");

    private static final ThreadLocal<String> currentTransactionName =
            new NamedThreadLocal<>("Current transaction name");

    private static final ThreadLocal<Boolean> currentTransactionReadOnly =
            new NamedThreadLocal<>("Current transaction read-only status");

    private static final ThreadLocal<Integer> currentTransactionIsolationLevel =
            new NamedThreadLocal<>("Current transaction isolation level");

    private static final ThreadLocal<Boolean> actualTransactionActive =
            new NamedThreadLocal<>("Actual transaction active");

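Spring 的事务同步(TransactionSynchronizationManager)底层同样基于 ThreadLocal:数据库连接等事务资源绑定在当前线程上。因此,如果把部分逻辑交给异步线程执行,异步线程拿不到调用方线程绑定的资源,也就不在调用方的事务中(对异步任务来说,调用方的事务是"失效"的),需要在异步任务中自己开启和控制新的事务。换言之,同一个事务一般只能在同一个线程中完成,不能跨多个线程。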

场景4:权限上下文信息

以表单填报项目中的 UserService 为例(注入当前用户的 SecurityIdentity):

@ApplicationScoped
public class UserService extends BaseService {
    private static final Logger logger = LoggerFactory.getLogger(UserService.class);

    @Inject
    SecurityIdentity identity;

    @Inject
    UserProfileCache userProfileCache;

    @Inject
    @RestClient
    BIService bi;

    @ConfigProperty(name = "custom.function.access.folder")
    Optional<Boolean> folderAccessOpt;

    @ConfigProperty(name = "custom.universe.domains")
    Optional<List<String>> universeDomainsOpt;

实现原理:

@RequestScoped
public class SecurityIdentityProxy implements SecurityIdentity {
    @Inject
    SecurityIdentityAssociation association;

    public SecurityIdentityProxy() {
    }

    public Principal getPrincipal() {
        return this.association.getIdentity().getPrincipal();
    }

也可以直接使用 ThreadLocal 实现:

package com.alibaba.survey.server.model;

import lombok.Getter;

/**
 * Holds the logged-in {@link User} for the current thread, so any code on that thread can read it
 * without passing it through method parameters.
 *
 * <p>The holder is an InheritableThreadLocal: a thread created by a thread that has a user set
 * starts with a copy of that user. The copy is taken only when the new thread is created. Pooled
 * threads therefore keep whatever they inherited or were last given. Callers must call
 * {@link #removeThreadLocalUser()} in a {@code finally} block when their request or task ends.
 */
@Getter
public class UnifiedLoginUser {

    private static final ThreadLocal<User> threadLocalUser = new InheritableThreadLocal<>();

    /** Returns the user bound to the current thread, or {@code null} if none is bound. */
    public static User getCurrentUser() {
        return threadLocalUser.get();
    }

    /** Binds {@code user} to the current thread (and to threads it creates afterwards). */
    public static void setThreadLocalUser(User user) {
        threadLocalUser.set(user);
    }

    /**
     * Unbinds the user from the current thread. Without this, a reused (pooled) thread keeps
     * returning the previous request's user from {@link #getCurrentUser()}.
     */
    public static void removeThreadLocalUser() {
        threadLocalUser.remove();
    }
}
// Reads the Authentication from the SecurityContext held by SecurityContextHolder. With the
// ThreadLocal-based holder strategy, this is the context bound to the current thread only.
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();

/**
 * {@link SecurityContextHolderStrategy} that keeps the {@link SecurityContext} in a plain
 * {@link ThreadLocal}. Each thread sees only its own context, and threads it starts do not
 * inherit it.
 */
final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
    private static final ThreadLocal<SecurityContext> contextHolder = new ThreadLocal<>();

    ThreadLocalSecurityContextHolderStrategy() {
    }

    /** Removes the current thread's context; the next getContext() creates a new empty one. */
    public void clearContext() {
        contextHolder.remove();
    }

    /**
     * Returns the current thread's context. If none is bound yet, an empty one is created and
     * bound first, so this never returns {@code null}.
     */
    public SecurityContext getContext() {
        SecurityContext ctx = contextHolder.get();
        if (ctx == null) {
            ctx = createEmptyContext();
            contextHolder.set(ctx);
        }
        return ctx;
    }

    /** Binds {@code context} to the current thread; null is rejected by Assert.notNull. */
    public void setContext(SecurityContext context) {
        Assert.notNull(context, "Only non-null SecurityContext instances are permitted");
        contextHolder.set(context);
    }

    /** Creates a new, empty {@link SecurityContextImpl}. */
    public SecurityContext createEmptyContext() {
        return new SecurityContextImpl();
    }
}

项目的登录信息一般会放在 ThreadLocal 中,方便同一线程内的任何地方获取,相当于一个线程内的缓存。请求结束时应调用 remove() 清理,原因见下文"疑问2"。

2.底层原理

Thread 类中保存 ThreadLocal 数据的字段:

// Field of java.lang.Thread: this thread's own map of ThreadLocal -> value. It stays null until
// ThreadLocal.createMap() assigns a new map to it, so every thread ends up with a separate map.
ThreadLocal.ThreadLocalMap threadLocals = null;

ThreadLocalMap

static class ThreadLocalMap {

        /**
         * The entries in this hash map extend WeakReference, using
         * its main ref field as the key (which is always a
         * ThreadLocal object).  Note that null keys (i.e. entry.get()
         * == null) mean that the key is no longer referenced, so the
         * entry can be expunged from table.  Such entries are referred to
         * as "stale entries" in the code that follows.
         */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * The initial capacity -- MUST be a power of two.
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * The table, resized as necessary.
         * table.length MUST always be a power of two.
         */
        private Entry[] table;

        /**
         * The number of entries in the table.
         */
        private int size = 0;

ThreadLocal#set 与 ThreadLocal#get 源码:

/**
 * ThreadLocal#set: stores {@code value} for the calling thread only. The value goes into the
 * calling thread's own map (its threadLocals field), keyed by this ThreadLocal instance.
 */
public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value); // first value on this thread: create its map with this entry
    }

/** Returns {@code t}'s ThreadLocalMap, or null if no map has been created for that thread yet. */
ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

/** Creates {@code t}'s map with (this ThreadLocal -> firstValue) as its first entry. */
void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
// ThreadLocal#get
/**
 * Returns the calling thread's value for this ThreadLocal by looking up this ThreadLocal's entry in
 * the calling thread's map. If the thread has no map, or the map has no entry for this
 * ThreadLocal, it delegates to setInitialValue() (not shown here).
 */
public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

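set 的步骤:

1.获取当前线程的引用;

2.获取当前线程对象的 threadLocals 字段(ThreadLocalMap,类似于一个 map);

3.若不为 null,则以当前 ThreadLocal 对象为 key、要设置的值为 value 存入该 map;

4.若为 null,则新建 map(第一个条目就是本次的 key/value),作为当前线程的 threadLocals 字段。

get 同理:从当前线程的 threadLocals 中以当前 ThreadLocal 对象为 key 取值;取不到时调用 setInitialValue()。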

3.使用不当可能遇到的问题

疑问1:Entry 的 key 为什么使用弱引用

static class ThreadLocalMap {

        /**
         * The entries in this hash map extend WeakReference, using
         * its main ref field as the key (which is always a
         * ThreadLocal object).  Note that null keys (i.e. entry.get()
         * == null) mean that the key is no longer referenced, so the
         * entry can be expunged from table.  Such entries are referred to
         * as "stale entries" in the code that follows.
         */
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * The initial capacity -- MUST be a power of two.
         */
        private static final int INITIAL_CAPACITY = 16;

        /**
         * The table, resized as necessary.
         * table.length MUST always be a power of two.
         */
        private Entry[] table;

先用例子分析一下 Java 中的引用类型,按强度从强到弱依次为:强引用 > 软引用 > 弱引用 > 虚引用。

package com.alibaba.fetchdata.service.repository;

import java.lang.ref.SoftReference;
import java.lang.ref.WeakReference;

/**
 * Demonstrates how the garbage collector treats strongly, softly and weakly reachable objects.
 *
 * <p>{@link System#gc()} is only a request to the JVM, so the printed results are typical rather
 * than guaranteed. For example, nothing is collected when explicit GC is disabled with
 * {@code -XX:+DisableExplicitGC}.
 */
public class ReferenceTest {

    public static void main(String[] args) {
        strongReferenceDemo();
        softReferenceDemo();
        weakReferenceDemo();
    }

    /** A strongly reachable object is never collected, so it still prints after GC. */
    private static void strongReferenceDemo() {
        Object strongObj = new Object();
        System.gc();
        System.out.println("强引用gc后:" + strongObj);
    }

    /**
     * An object reachable only through a SoftReference survives an ordinary GC. It is cleared
     * before the JVM throws OutOfMemoryError, so get() typically prints null after memory runs out.
     */
    private static void softReferenceDemo() {
        SoftReference<Object> softReference = new SoftReference<>(new Object());
        System.out.println("软引用gc前:" + softReference.get());
        System.gc();
        System.out.println("软引用gc后:" + softReference.get());
        exhaustMemory();
        System.out.println("软引用内存不足后:" + softReference.get());
    }

    /**
     * Variant of softReferenceDemo() where a local variable also holds the object. The object is
     * strongly reachable, so it is not cleared even when memory runs out. Not called from main;
     * call it instead of softReferenceDemo() to compare.
     */
    private static void softReferenceWithStrongRootDemo() {
        Object softObj = new Object(); // strong reference: the object is not reclaimed
        SoftReference<Object> softReference = new SoftReference<>(softObj);
        System.out.println("软引用gc前:" + softReference.get());
        System.gc();
        System.out.println("软引用gc后:" + softReference.get());
        exhaustMemory();
        System.out.println("软引用内存不足后:" + softReference.get());
    }

    /** An object reachable only through a WeakReference is typically cleared by the next GC. */
    private static void weakReferenceDemo() {
        WeakReference<Object> objectWeakReference = new WeakReference<>(new Object());
        System.out.println("弱引用gc前:" + objectWeakReference.get());
        System.gc();
        System.out.println("弱引用gc后:" + objectWeakReference.get());
    }

    /**
     * Variant of weakReferenceDemo() where a local variable keeps the object strongly reachable,
     * so GC does not clear it. Not called from main; call it instead of weakReferenceDemo().
     */
    private static void weakReferenceWithStrongRootDemo() {
        Object weakObject = new Object(); // strong reference: the object is not reclaimed
        WeakReference<Object> objectWeakReference = new WeakReference<>(weakObject);
        System.out.println("弱引用gc前:" + objectWeakReference.get());
        System.gc();
        System.out.println("弱引用gc后:" + objectWeakReference.get());
    }

    /**
     * Requests an array of 2^30 references, which normally exceeds the heap and raises
     * OutOfMemoryError. The error is printed and swallowed so the demo can continue.
     */
    private static void exhaustMemory() {
        try {
            Object[] hog = new Object[1024 * 1024 * 1024];
        } catch (OutOfMemoryError error) {
            System.out.println(error);
        }
    }
}

分析 ThreadLocal:

// The variable `threadlocal` is the strong reference to the ThreadLocal object. The key of the
// Entry stored in each thread's ThreadLocalMap refers to that object only weakly.
ThreadLocal<Object> threadlocal = new ThreadLocal<>();

Thread -> ThreadLocal.ThreadLocalMap -> Entry[] -> Entry -> key(threadLocal对象)和value

所以 ThreadLocal 对象同时被两处引用:变量 threadlocal(强引用),以及当前线程 ThreadLocalMap 中 Entry 的 key(弱引用)。

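当 threadlocal 不可访问(例如 threadlocal = null)后,ThreadLocal 对象只剩 Entry 的 key 这一处弱引用,垃圾回收时就会被回收,对应 Entry 的 key 变为 null。如果 key 设计成强引用,只要线程还活着,这个 ThreadLocal 对象就永远无法被回收。

但要注意:Entry 对 value 仍是强引用。key 被回收后,value 要等到 ThreadLocalMap 在后续操作中清理这些 key 为 null 的 "stale entry",或者等线程结束,才会被释放。在线程池这类线程长期存活的场景里,这就可能造成内存泄漏,所以用完后应主动调用 remove()。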

参考文章:https://zhuanlan.zhihu.com/p/304240519

疑问2:到底需要不需要remove

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Shows why a ThreadLocal value must be removed when a pooled thread finishes a task.
 *
 * <p>Ten tasks run on a fixed pool of three threads. Each task prints the value it finds, sets its
 * own value, and prints again. The first three tasks each start a new pool thread and find null.
 * The remaining seven run on threads that already executed a task, so they find the value left
 * behind by that earlier task ("设值前" is not null).
 *
 * <p>The demo deliberately omits cleanup. The fix is to wrap the task body in
 * {@code try { ... } finally { threadLocal.remove(); }}.
 */
public class ThreadLocalDemo {

    private static final ExecutorService executorService = Executors.newFixedThreadPool(3);
    private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();

    public static void main(String[] args) {
        for (int i = 0; i < 10; i++) {
            int j = i + 1;
            executorService.execute(() -> {
                System.out.println(Thread.currentThread().getName() + "---设值前threadLocal=" + threadLocal.get());
                threadLocal.set(Thread.currentThread().getName() + " and 第" + j + "个请求");
                System.out.println(Thread.currentThread().getName() + "---设值后threadLocal=" + threadLocal.get());
            });
        }
        executorService.shutdown();
    }
}

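上述代码有无问题?

有问题:线程池只有 3 个线程,前 3 个任务各自在新线程上运行,之后的 7 个任务都运行在已经执行过任务的线程上,"设值前"就能读到上一个任务留下的值。

问题原因:线程池会复用线程,核心线程基本不会销毁,上一个任务 set 的值会一直留在该线程的 ThreadLocalMap 中。

解决方案:在任务中使用 try { ... } finally { threadLocal.remove(); } 清理。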

疑问3:需要异步执行的代码如何获取 ThreadLocal 中的上下文信息?

解决方式1:执行异步任务前保存上下文信息,然后执行异步任务时初始化上下文

以某个项目批量更新为例,需要异步执行更新逻辑,源码如下:

/**
 * Records a batch-modify task for the given table, then hands the update to an asynchronous worker.
 *
 * The caller's Authentication is read here, on the calling thread, and attached to the DTO. The
 * worker runs on a different thread, where the caller's thread-bound security context is not
 * available. Presumably the async handler restores the context from the DTO; confirm against the
 * consumer.
 */
public void doBatchUpdate(BatchUpdateDataDTO batchUpdateData) {
    String table = batchUpdateData.getTableName();

    Authentication caller = SecurityContextHolder.getContext().getAuthentication();
    UserInfo callerInfo = new UserInfo(caller);

    String newTaskId = UUID.randomUUID().toString().replace("-", "");
    TaskRepository.insertTask(
            new TaskDAO(newTaskId, TaskTypeEnum.BATCH_MODIFY_PT, table, TaskStatusEnum.RUNNING, callerInfo.getUserId()));

    // Carry the task id and the caller's identity over to the asynchronous side.
    batchUpdateData.setTaskId(newTaskId);
    batchUpdateData.setAuthentication(caller);

    // Asynchronous: the handler for this event runs on another thread.
    disruptorTaskService.sendNotify(EventTypeEnum.BATCH_UPDATE_TABLE, batchUpdateData);
}

异步处理逻辑:

(原文此处为异步处理逻辑截图,图片缺失)

缺点:需要手动传递和设置上下文,当需要传递的 ThreadLocal 变量比较多时,维护成本很高,也容易遗漏。

解决方式2:使用InheritableThreadLocal,原理其实和方式1类似

Thread 初始化逻辑

/**
 * java.lang.Thread#init (excerpt): initialises a new, not yet started thread, copying several
 * settings from the thread that performs the initialisation (the "parent").
 *
 * @param g                   thread group; if null, taken from the SecurityManager, else the parent's
 * @param target              Runnable stored as this thread's target
 * @param name                thread name; must not be null
 * @param stackSize           requested stack size, stored for the VM
 * @param acc                 access control context to inherit; if null, AccessController.getContext()
 * @param inheritThreadLocals whether to copy the parent's inheritableThreadLocals
 * @throws NullPointerException if name is null
 */
private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc,
                      boolean inheritThreadLocals) {
        if (name == null) {
            throw new NullPointerException("name cannot be null");
        }

        this.name = name;

        // the thread executing this initialisation becomes the new thread's parent
        Thread parent = currentThread();
        SecurityManager security = System.getSecurityManager();
        if (g == null) {
            /* Determine if it's an applet or not */

            /* If there is a security manager, ask the security manager
               what to do. */
            if (security != null) {
                g = security.getThreadGroup();
            }

            /* If the security doesn't have a strong opinion of the matter
               use the parent thread group. */
            if (g == null) {
                g = parent.getThreadGroup();
            }
        }

        /* checkAccess regardless of whether or not threadgroup is
           explicitly passed in. */
        g.checkAccess();

        /*
         * Do we have the required permissions?
         */
        if (security != null) {
            if (isCCLOverridden(getClass())) {
                security.checkPermission(SUBCLASS_IMPLEMENTATION_PERMISSION);
            }
        }

        g.addUnstarted();

        this.group = g;
        // daemon status, priority and context class loader are taken from the parent
        this.daemon = parent.isDaemon();
        this.priority = parent.getPriority();
        if (security == null || isCCLOverridden(parent.getClass()))
            this.contextClassLoader = parent.getContextClassLoader();
        else
            this.contextClassLoader = parent.contextClassLoader;
        this.inheritedAccessControlContext =
                acc != null ? acc : AccessController.getContext();
        this.target = target;
        setPriority(priority);
        // InheritableThreadLocal support: the child receives its own map built from the parent's
        // inheritableThreadLocals, once, here during initialisation. Values changed later in either
        // thread are not propagated, so a reused pool thread keeps the snapshot from its creation.
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        /* Stash the specified stack size in case the VM cares */
        this.stackSize = stackSize;

        /* Set thread ID */
        tid = nextThreadID();
    }

在一个线程中创建新线程时,新线程为子线程,创建它的线程为父线程。子线程初始化时会把父线程的 inheritableThreadLocals 数据拷贝一份,这样子线程就有了父线程的上下文信息。这和解决方式1异曲同工,只是由框架自动完成,避免了手动设置带来的各种问题。

需要注意:拷贝只发生在子线程创建的那一刻。线程池中被复用的线程不会再次拷贝,拿到的可能是创建时的旧值,所以 InheritableThreadLocal 不适合用来在线程池中传递上下文。spring-security 也提供了这种策略:

 class InheritableThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
    private static final ThreadLocal<SecurityContext> contextHolder = new InheritableThreadLocal();

    InheritableThreadLocalSecurityContextHolderStrategy() {
    }

    public void clearContext() {
        contextHolder.remove();
    }

    public SecurityContext getContext() {
        SecurityContext ctx = (SecurityContext)contextHolder.get();
        if (ctx == null) {
            ctx = this.createEmptyContext();
            contextHolder.set(ctx);
        }

        return ctx;
    }
上一篇 下一篇

猜你喜欢

热点阅读