ThreadLocal使用和原理
1.使用场景
场景1:全局获取request,response
ServletRequestAttributes requestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
HttpServletRequest request = requestAttributes.getRequest();
HttpServletResponse response = requestAttributes.getResponse();
可以查看RequestContextHolder源码:
package org.springframework.web.context.request;
import javax.faces.context.FacesContext;
import org.springframework.core.NamedInheritableThreadLocal;
import org.springframework.core.NamedThreadLocal;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
/**
 * Static holder for the {@link RequestAttributes} associated with the current thread.
 *
 * <p>Two thread-local slots are kept: a plain one and an inheritable one. At most one of
 * them is populated by {@link #setRequestAttributes(RequestAttributes, boolean)}; the other
 * is cleared in the same call.
 */
public abstract class RequestContextHolder {
    /** Whether the JSF {@code FacesContext} class can be loaded by this class's class loader. */
    private static final boolean jsfPresent = ClassUtils.isPresent("javax.faces.context.FacesContext", RequestContextHolder.class.getClassLoader());
    /** Slot used when the attributes are bound without inheritance. */
    private static final ThreadLocal<RequestAttributes> requestAttributesHolder = new NamedThreadLocal("Request attributes");
    /** Slot used when the attributes are bound with {@code inheritable == true}. */
    private static final ThreadLocal<RequestAttributes> inheritableRequestAttributesHolder = new NamedInheritableThreadLocal("Request context");

    public RequestContextHolder() {
    }

    /** Clears both thread-local slots for the calling thread. */
    public static void resetRequestAttributes() {
        requestAttributesHolder.remove();
        inheritableRequestAttributesHolder.remove();
    }

    /**
     * Binds the given attributes to the calling thread using the non-inheritable slot.
     *
     * @param attributes the attributes to bind, or {@code null} to clear the binding
     */
    public static void setRequestAttributes(@Nullable RequestAttributes attributes) {
        setRequestAttributes(attributes, false);
    }

    /**
     * Binds the given attributes to the calling thread.
     *
     * @param attributes  the attributes to bind, or {@code null} to clear both slots
     * @param inheritable {@code true} to store them in the inheritable slot,
     *                    {@code false} to store them in the plain slot
     */
    public static void setRequestAttributes(@Nullable RequestAttributes attributes, boolean inheritable) {
        if (attributes == null) {
            resetRequestAttributes();
            return;
        }
        ThreadLocal<RequestAttributes> chosen =
                inheritable ? inheritableRequestAttributesHolder : requestAttributesHolder;
        ThreadLocal<RequestAttributes> unused =
                inheritable ? requestAttributesHolder : inheritableRequestAttributesHolder;
        // Populate the chosen slot first, then clear the other one.
        chosen.set(attributes);
        unused.remove();
    }

    /**
     * Returns the attributes bound to the calling thread, checking the plain slot first
     * and the inheritable slot second.
     *
     * @return the bound attributes, or {@code null} if neither slot holds a value
     */
    @Nullable
    public static RequestAttributes getRequestAttributes() {
        RequestAttributes bound = (RequestAttributes) requestAttributesHolder.get();
        return bound != null ? bound : (RequestAttributes) inheritableRequestAttributesHolder.get();
    }

    /**
     * Returns the attributes bound to the calling thread, falling back to the JSF
     * {@code FacesContext} when JSF is present.
     *
     * @return the current attributes, never {@code null}
     * @throws IllegalStateException if nothing is bound and no JSF fallback is available
     */
    public static RequestAttributes currentRequestAttributes() throws IllegalStateException {
        RequestAttributes attributes = getRequestAttributes();
        if (attributes != null) {
            return attributes;
        }
        if (jsfPresent) {
            attributes = RequestContextHolder.FacesRequestAttributesFactory.getFacesRequestAttributes();
        }
        if (attributes == null) {
            throw new IllegalStateException("No thread-bound request found: Are you referring to request attributes outside of an actual web request, or processing a request outside of the originally receiving thread? If you are actually operating within a web request and still receive this message, your code is probably running outside of DispatcherServlet: In this case, use RequestContextListener or RequestContextFilter to expose the current request.");
        }
        return attributes;
    }

    /** Nested helper so that JSF classes are only referenced from this class. */
    private static class FacesRequestAttributesFactory {
        private FacesRequestAttributesFactory() {
        }

        /**
         * @return attributes wrapping the current {@code FacesContext}, or {@code null}
         *         when there is no current instance
         */
        @Nullable
        public static RequestAttributes getFacesRequestAttributes() {
            FacesContext facesContext = FacesContext.getCurrentInstance();
            if (facesContext == null) {
                return null;
            }
            return new FacesRequestAttributes(facesContext);
        }
    }
}
场景2:日志上下文信息
<!-- Rolling file appender named ERROR-FILE; the active file is ${LOG_PATH}/error.log. -->
<appender name="ERROR-FILE"
class="ch.qos.logback.core.rolling.RollingFileAppender">
<encoder>
<!-- Each line prints the timestamp, then the MDC entries traceId, uId, domId and uri
     (read through %X{key}), then level, logger name (abbreviated to 35 chars) and message. -->
<pattern><![CDATA[
[%d{yyyy-MM-dd HH:mm:ss}] [traceid=%X{traceId},uid=%X{uId},corpid=%X{domId},uri=%X{uri}] %-5level %logger{35} - %m%n
]]></pattern>
<charset>utf8</charset>
</encoder>
<file>${LOG_PATH}/error.log</file>
<!-- Rolled files are named by date plus an index (%i); a new index is started once the
     current file reaches maxFileSize. -->
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
<fileNamePattern>${LOG_PATH}/error.%d{yyyy-MM-dd}-%i.log</fileNamePattern>
<!-- NOTE(review): presumably keeps 5 rollover periods (days, given the date pattern) - confirm against the logback version in use. -->
<MaxHistory>5</MaxHistory>
<timeBasedFileNamingAndTriggeringPolicy
class="ch.qos.logback.core.rolling.SizeAndTimeBasedFNATP">
<maxFileSize>500MB</maxFileSize>
</timeBasedFileNamingAndTriggeringPolicy>
</rollingPolicy>
<!-- Only events whose level is exactly ERROR are accepted; every other level is denied. -->
<filter class="ch.qos.logback.classic.filter.LevelFilter">
<level>ERROR</level>
<onMatch>ACCEPT</onMatch>
<onMismatch>DENY</onMismatch>
</filter>
</appender>
一般是在请求入口处通过org.slf4j.MDC放入上下文信息,输出日志时由pattern中的%X{key}取出并打印
输出日志例子:
(原文此处为日志输出截图)
代码:
package com.alibaba.atlas.security;
import org.jboss.logging.MDC;
import org.springframework.util.AntPathMatcher;
import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
/**
 * Servlet filter that fills the logging MDC with per-request context (referer, user agent,
 * URI, a freshly generated trace id, plus fixed demo values for domain and user id) for the
 * duration of the request, and clears the MDC afterwards.
 *
 * <p>Requests whose URI matches one of the configured Ant-style exclude patterns bypass the
 * MDC handling entirely.
 *
 * <p>NOTE(review): the {@code MDC} class used here is whichever one the file imports; make
 * sure it is the one the logging backend's {@code %X{...}} pattern actually reads from.
 *
 * @author dxc
 */
public class LoggerFilter implements Filter {
    private static final String REFERER_KEY = "referer";
    private static final String USER_AGENT_KEY = "ua";
    private static final String UID_KEY = "uId";
    private static final String DOM_KEY = "domId";
    private static final String URI_KEY = "uri";
    public static final String TRACE_ID_KEY = "traceId";
    private static final String USER_AGENT_NAME = "User-Agent";
    private static final String REFERER_NAME = "Referer";
    /**
     * health health2 (not referenced anywhere in this class)
     */
    private static final String HEALTH = "/health";
    private static final AntPathMatcher ANT_PATH_MATCHER = new AntPathMatcher();

    /**
     * Exclude patterns of this filter instance. Kept per instance (it used to be a static
     * field overwritten by the constructor, so creating a second filter silently changed
     * the exclusions of the first).
     */
    private final List<String> excludeUrlList;

    /**
     * @param excludeUrl comma-separated Ant-style URI patterns to skip, or {@code null}
     *                   for no exclusions; surrounding whitespace is trimmed and blank
     *                   entries are ignored
     */
    public LoggerFilter(String excludeUrl) {
        List<String> patterns = new ArrayList<>();
        if (excludeUrl != null) {
            for (String candidate : excludeUrl.split(",")) {
                String pattern = candidate.trim();
                if (!pattern.isEmpty()) {
                    patterns.add(pattern);
                }
            }
        }
        this.excludeUrlList = patterns;
    }

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    /**
     * Populates the MDC, delegates to the rest of the chain, and always clears the MDC in
     * {@code finally} so that values do not survive on the thread after the request.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) servletRequest;
        String requestURI = httpServletRequest.getRequestURI();
        if (isExclude(requestURI)) {
            chain.doFilter(servletRequest, servletResponse);
            return;
        }
        String userAgent = httpServletRequest.getHeader(USER_AGENT_NAME);
        String referer = httpServletRequest.getHeader(REFERER_NAME);
        try {
            MDC.put(REFERER_KEY, referer);
            MDC.put(DOM_KEY, "demo");
            MDC.put(UID_KEY, "dxc");
            MDC.put(USER_AGENT_KEY, userAgent);
            MDC.put(URI_KEY, requestURI);
            MDC.put(TRACE_ID_KEY, UUID.randomUUID().toString().replace("-", ""));
            chain.doFilter(servletRequest, servletResponse);
        } finally {
            MDC.clear();
        }
    }

    @Override
    public void destroy() {
    }

    /** Returns whether the URI matches any exclude pattern of this instance. */
    private boolean isExclude(String uri) {
        return matchUrI(excludeUrlList, uri);
    }

    /** Returns {@code true} as soon as one pattern in {@code list} matches {@code uri}. */
    private static boolean matchUrI(List<String> list, String uri) {
        for (String pattern : list) {
            if (ANT_PATH_MATCHER.match(pattern, uri)) {
                return true;
            }
        }
        return false;
    }
}
MDC源码:
package ch.qos.logback.classic.util;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.slf4j.spi.MDCAdapter;
/**
 * {@link MDCAdapter} implementation that stores one {@code Map<String, String>} per thread
 * in a {@link ThreadLocal}.
 *
 * <p>A second thread-local remembers the last operation performed by the thread. Writes that
 * follow a {@link #getPropertyMap()} call (or that are the first operation on the thread)
 * first copy the current map into a new one, so a map handed out earlier is not mutated by
 * later writes.
 */
public class LogbackMDCAdapter implements MDCAdapter {
// Per-thread context map.
final ThreadLocal<Map<String, String>> copyOnThreadLocal = new ThreadLocal();
// Operation codes stored in lastOperation; the methods below use the literals 1 and 2 directly.
private static final int WRITE_OPERATION = 1;
private static final int MAP_COPY_OPERATION = 2;
// Per-thread code of the most recent operation; null until the first operation.
final ThreadLocal<Integer> lastOperation = new ThreadLocal();
public LogbackMDCAdapter() {
}
/** Records {@code op} as the last operation and returns the previously recorded one (may be null). */
private Integer getAndSetLastOperation(int op) {
Integer lastOp = (Integer)this.lastOperation.get();
this.lastOperation.set(op);
return lastOp;
}
/** True when nothing has been recorded yet or the last recorded code was 2 (set by getPropertyMap). */
private boolean wasLastOpReadOrNull(Integer lastOp) {
return lastOp == null || lastOp == 2;
}
/**
 * Creates a new synchronized map, copies {@code oldMap} into it (if non-null) while holding
 * {@code oldMap}'s monitor, installs the new map for the current thread and returns it.
 */
private Map<String, String> duplicateAndInsertNewMap(Map<String, String> oldMap) {
Map<String, String> newMap = Collections.synchronizedMap(new HashMap());
if (oldMap != null) {
synchronized(oldMap) {
newMap.putAll(oldMap);
}
}
this.copyOnThreadLocal.set(newMap);
return newMap;
}
/**
 * Associates {@code val} with {@code key} in the current thread's map.
 *
 * @throws IllegalArgumentException if {@code key} is null
 */
public void put(String key, String val) throws IllegalArgumentException {
if (key == null) {
throw new IllegalArgumentException("key cannot be null");
} else {
Map<String, String> oldMap = (Map)this.copyOnThreadLocal.get();
Integer lastOp = this.getAndSetLastOperation(1);
// Write in place only if the previous operation was also a write and a map already exists;
// otherwise write into a fresh copy.
if (!this.wasLastOpReadOrNull(lastOp) && oldMap != null) {
oldMap.put(key, val);
} else {
Map<String, String> newMap = this.duplicateAndInsertNewMap(oldMap);
newMap.put(key, val);
}
}
}
/** Removes {@code key} from the current thread's map; no-op for a null key or when no map exists. */
public void remove(String key) {
if (key != null) {
Map<String, String> oldMap = (Map)this.copyOnThreadLocal.get();
if (oldMap != null) {
Integer lastOp = this.getAndSetLastOperation(1);
// Same copy-before-write decision as in put().
if (this.wasLastOpReadOrNull(lastOp)) {
Map<String, String> newMap = this.duplicateAndInsertNewMap(oldMap);
newMap.remove(key);
} else {
oldMap.remove(key);
}
}
}
}
/** Drops the current thread's map and records a write (1) as the last operation. */
public void clear() {
this.lastOperation.set(1);
this.copyOnThreadLocal.remove();
}
/** Returns the value for {@code key}, or null when there is no map or {@code key} is null. */
public String get(String key) {
Map<String, String> map = (Map)this.copyOnThreadLocal.get();
return map != null && key != null ? (String)map.get(key) : null;
}
/**
 * Returns the current thread's map itself (not a copy; may be null) and records code 2, so
 * that the next put/remove on this thread works on a fresh copy.
 */
public Map<String, String> getPropertyMap() {
this.lastOperation.set(2);
return (Map)this.copyOnThreadLocal.get();
}
/** Returns the key set of the current map, or null when there is no map; goes through getPropertyMap(). */
public Set<String> getKeys() {
Map<String, String> map = this.getPropertyMap();
return map != null ? map.keySet() : null;
}
/** Returns a new HashMap copy of the current map, or null when there is no map. */
public Map<String, String> getCopyOfContextMap() {
Map<String, String> hashMap = (Map)this.copyOnThreadLocal.get();
return hashMap == null ? null : new HashMap(hashMap);
}
/**
 * Replaces the current thread's map with a new synchronized map holding the entries of
 * {@code contextMap}.
 * NOTE(review): there is no null check, so a null {@code contextMap} will fail inside putAll.
 */
public void setContextMap(Map<String, String> contextMap) {
this.lastOperation.set(1);
Map<String, String> newMap = Collections.synchronizedMap(new HashMap());
newMap.putAll(contextMap);
this.copyOnThreadLocal.set(newMap);
}
}
场景三:数据库事务
AbstractPlatformTransactionManager
public abstract class TransactionSynchronizationManager {
private static final Log logger = LogFactory.getLog(TransactionSynchronizationManager.class);
private static final ThreadLocal<Map<Object, Object>> resources =
new NamedThreadLocal<>("Transactional resources");
private static final ThreadLocal<Set<TransactionSynchronization>> synchronizations =
new NamedThreadLocal<>("Transaction synchronizations");
private static final ThreadLocal<String> currentTransactionName =
new NamedThreadLocal<>("Current transaction name");
private static final ThreadLocal<Boolean> currentTransactionReadOnly =
new NamedThreadLocal<>("Current transaction read-only status");
private static final ThreadLocal<Integer> currentTransactionIsolationLevel =
new NamedThreadLocal<>("Current transaction isolation level");
private static final ThreadLocal<Boolean> actualTransactionActive =
new NamedThreadLocal<>("Actual transaction active");
数据库事务底层也是用ThreadLocal保存事务资源和事务状态,所以如果把某段逻辑交给异步线程去处理,异步线程拿不到当前线程的事务上下文,原事务对异步任务是失效的,需要在异步任务中自己去开启并控制新的事务。因此同一个事务一般是在同一个线程中完成,不能跨多个线程。
场景四:权限上下文信息
表单填报UserService
@ApplicationScoped
public class UserService extends BaseService {
private static final Logger logger = LoggerFactory.getLogger(UserService.class);
@Inject
SecurityIdentity identity;
@Inject
UserProfileCache userProfileCache;
@Inject
@RestClient
BIService bi;
@ConfigProperty(name = "custom.function.access.folder")
Optional<Boolean> folderAccessOpt;
@ConfigProperty(name = "custom.universe.domains")
Optional<List<String>> universeDomainsOpt;
实现原理:
@RequestScoped
public class SecurityIdentityProxy implements SecurityIdentity {
@Inject
SecurityIdentityAssociation association;
public SecurityIdentityProxy() {
}
public Principal getPrincipal() {
return this.association.getIdentity().getPrincipal();
}
使用ThreadLocal实现
package com.alibaba.survey.server.model;
import lombok.Getter;
/**
 * Thread-bound holder for the currently logged-in {@link User}.
 *
 * <p>The value lives in an {@link InheritableThreadLocal}. Callers that run on reused
 * threads should call {@link #removeThreadLocalUser()} when the unit of work ends, otherwise
 * the previous user stays bound to the thread.
 */
@Getter
public class UnifiedLoginUser {
    /** Holder for the current user; final so the reference itself can never be swapped. */
    private static final ThreadLocal<User> threadLocalUser = new InheritableThreadLocal<>();

    /**
     * @return the user bound to the calling thread, or {@code null} if none is bound
     */
    public static User getCurrentUser() {
        return threadLocalUser.get();
    }

    /**
     * Binds {@code user} to the calling thread.
     *
     * @param user the user to bind; {@code null} removes the binding instead of storing null
     */
    public static void setThreadLocalUser(User user){
        if (user == null) {
            threadLocalUser.remove();
            return;
        }
        threadLocalUser.set(user);
    }

    /**
     * Removes the binding for the calling thread. Intended for a {@code finally} block at
     * the end of request or task processing.
     */
    public static void removeThreadLocalUser() {
        threadLocalUser.remove();
    }
}
Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
/**
 * {@link SecurityContextHolderStrategy} that keeps the {@link SecurityContext} in a plain
 * {@link ThreadLocal}, so each thread sees only the context it stored itself.
 */
final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
    private static final ThreadLocal<SecurityContext> contextHolder = new ThreadLocal();

    ThreadLocalSecurityContextHolderStrategy() {
    }

    /** Removes the context bound to the calling thread. */
    public void clearContext() {
        contextHolder.remove();
    }

    /**
     * Returns the context bound to the calling thread; when none is bound, an empty one is
     * created, bound, and returned.
     */
    public SecurityContext getContext() {
        SecurityContext existing = (SecurityContext) contextHolder.get();
        if (existing != null) {
            return existing;
        }
        SecurityContext fresh = this.createEmptyContext();
        contextHolder.set(fresh);
        return fresh;
    }

    /**
     * Binds {@code context} to the calling thread; the argument is checked with
     * {@code Assert.notNull} first.
     */
    public void setContext(SecurityContext context) {
        Assert.notNull(context, "Only non-null SecurityContext instances are permitted");
        contextHolder.set(context);
    }

    /** Creates a new, empty {@link SecurityContextImpl}. */
    public SecurityContext createEmptyContext() {
        return new SecurityContextImpl();
    }
}
项目的登录信息一般会放在ThreadLocal中,方便在所有的地方去获取,相当于一个线程内的缓存。
2.底层原理
Thread
ThreadLocal.ThreadLocalMap threadLocals = null;
ThreadLocalMap
static class ThreadLocalMap {
/**
* The entries in this hash map extend WeakReference, using
* its main ref field as the key (which is always a
* ThreadLocal object). Note that null keys (i.e. entry.get()
* == null) mean that the key is no longer referenced, so the
* entry can be expunged from table. Such entries are referred to
* as "stale entries" in the code that follows.
*/
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
/**
* The initial capacity -- MUST be a power of two.
*/
private static final int INITIAL_CAPACITY = 16;
/**
* The table, resized as necessary.
* table.length MUST always be a power of two.
*/
private Entry[] table;
/**
* The number of entries in the table.
*/
private int size = 0;
ThreadLocal#set
/**
 * Stores {@code value} for this thread-local in the calling thread's map, creating the map
 * on first use.
 *
 * @param value the value to associate with this thread-local for the calling thread
 */
public void set(T value) {
    Thread current = Thread.currentThread();
    ThreadLocalMap locals = getMap(current);
    if (locals == null) {
        // First thread-local written on this thread: create the map with this entry in it.
        createMap(current, value);
        return;
    }
    locals.set(this, value);
}
/**
 * Returns the map stored in the {@code threadLocals} field of the given thread.
 *
 * @param t the thread whose map is read
 * @return the value of {@code t.threadLocals} as-is
 */
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
/**
 * Creates a new map containing a single entry (this thread-local mapped to
 * {@code firstValue}) and assigns it to the {@code threadLocals} field of {@code t}.
 *
 * @param t          the thread that receives the new map
 * @param firstValue the value for the initial entry
 */
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
//get
/**
 * Returns the calling thread's value for this thread-local.
 *
 * <p>When the thread has no map yet, or the map has no entry for this thread-local, the
 * result of {@code setInitialValue()} is returned instead.
 *
 * @return the value held for the calling thread
 */
public T get() {
    ThreadLocalMap locals = getMap(Thread.currentThread());
    ThreadLocalMap.Entry hit = (locals == null) ? null : locals.getEntry(this);
    if (hit == null) {
        return setInitialValue();
    }
    @SuppressWarnings("unchecked")
    T stored = (T) hit.value;
    return stored;
}
步骤:
1.获取当前线程的引用;
2.获取当前线程对象的threadLocals字段(类似于一个map);
3.不为null则以ThreadLocal对象为key、需要设置的值为value,写入该map;
4.为null则新建map,作为当前线程的threadLocals变量
3.使用不当遇到的问题
疑问1:entry为什么用弱引用
static class ThreadLocalMap {
/**
* The entries in this hash map extend WeakReference, using
* its main ref field as the key (which is always a
* ThreadLocal object). Note that null keys (i.e. entry.get()
* == null) mean that the key is no longer referenced, so the
* entry can be expunged from table. Such entries are referred to
* as "stale entries" in the code that follows.
*/
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
/**
* The initial capacity -- MUST be a power of two.
*/
private static final int INITIAL_CAPACITY = 16;
/**
* The table, resized as necessary.
* table.length MUST always be a power of two.
*/
private Entry[] table;
先用例子分析下java中的引用类型 强引用 >软引用>弱引用>虚引用
package com.alibaba.fetchdata.service.repository;
import java.lang.ref.SoftReference;
import java.lang.ref.WeakReference;
/**
 * Console demo that prints what strong, soft and weak references return around
 * {@code System.gc()} calls and around an attempt to allocate a very large array.
 */
public class ReferenceTest {
    public static void main(String[] args) {
        demoStrongReference();
        demoSoftReference();
        demoWeakReference();
    }

    /** Strong reference: print the object after requesting a GC. */
    private static void demoStrongReference() {
        Object strongObj = new Object();
        System.gc();
        System.out.println("强引用gc后:" + strongObj);
    }

    /**
     * Soft reference: print the referent before a GC, after a GC, and after trying to
     * allocate a huge array (any Throwable from that allocation is printed, not rethrown).
     *
     * Variant to try (the referent is also strongly referenced, so it will not be collected):
     *     Object softObj = new Object();
     *     SoftReference<Object> softReference = new SoftReference<>(softObj);
     * followed by the same prints.
     */
    private static void demoSoftReference() {
        SoftReference<Object> softReference = new SoftReference<>(new Object());
        System.out.println("软引用gc前:" + softReference.get());
        System.gc();
        System.out.println("软引用gc后:" + softReference.get());
        try {
            Object[] hugeArray = new Object[1024 * 1024 * 1024];
        } catch (Throwable failure) {
            System.out.println(failure);
        }
        System.out.println("软引用内存不足后:" + softReference.get());
    }

    /**
     * Weak reference: print the referent before and after requesting a GC.
     *
     * Variant to try (keeps an extra strong reference to the referent):
     *     Object weakObject = new Object();
     *     WeakReference<Object> objectWeakReference = new WeakReference<>(weakObject);
     * followed by the same prints.
     */
    private static void demoWeakReference() {
        WeakReference<Object> objectWeakReference = new WeakReference<>(new Object());
        System.out.println("弱引用gc前:" + objectWeakReference.get());
        System.gc();
        System.out.println("弱引用gc后:" + objectWeakReference.get());
    }
}
分析Threadlocal:
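ThreadLocal<Object> threadLocal = new ThreadLocal<Object>();
Thread -> ThreadLocal.ThreadLocalMap -> Entry[] -> Entry -> key(threadLocal对象)和value
所以ThreadLocal对象同时被两处引用:变量threadLocal对它的强引用,以及Thread内ThreadLocalMap中Entry的key对它的弱引用。
当外部强引用断开(例如 threadLocal = null)后,ThreadLocal对象只剩key这一处弱引用,垃圾回收器收集时对应的ThreadLocal对象就会被回收;如果key不设计成弱引用,只要线程还存活,它就不会被回收器回收。注意:key被回收后,Entry的value仍被强引用,要等后续set/get/remove清理这些stale entry或线程结束才会释放,所以用完仍应主动调用remove()。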
参考文章:https://zhuanlan.zhihu.com/p/304240519
疑问2:到底需要不需要remove
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
 * Demo: ten tasks run on a three-thread pool, each printing a shared thread-local before
 * and after setting it.
 *
 * <p>The tasks deliberately never call {@code remove()}, so the "before" print of a task
 * shows whatever an earlier task left on the same pooled thread.
 */
public class ThreadLocalDemo {
    private static final ExecutorService executorService = Executors.newFixedThreadPool(3);
    private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();

    public static void main(String[] args) {
        for (int requestNo = 1; requestNo <= 10; requestNo++) {
            final int j = requestNo;
            executorService.execute(() -> handleRequest(j));
        }
        executorService.shutdown();
    }

    /** Prints the thread-local, overwrites it with a value naming this request, prints it again. */
    private static void handleRequest(int j) {
        String threadName = Thread.currentThread().getName();
        System.out.println(threadName + "---设值前threadLocal=" + threadLocal.get());
        threadLocal.set(threadName + " and 第" + j + "个请求");
        System.out.println(threadName + "---设值后threadLocal=" + threadLocal.get());
    }
}
上述代码有无问题:
问题原因:线程池会复用线程,核心线程基本不会销毁;上一个任务set的值没有清理,下一个复用该线程的任务在设值前就能读到旧值
解决方案:在任务中使用try finally,在finally里调用threadLocal.remove()清理
疑问3:代码中需要异步执行的代码怎么处理?
解决方式1:执行异步任务前保存上下文信息,然后执行异步任务时初始化上下文
以某个项目批量更新为例,需要异步执行更新逻辑,源码如下:
/**
 * Registers a RUNNING batch-modify task for the requested table and hands the request over
 * to the disruptor task service.
 *
 * <p>The caller's {@code Authentication} is read from the security context here and stored
 * on the DTO before the hand-over, so it travels with the request.
 *
 * @param batchUpdateData the batch update request; its task id and authentication are set
 *                        by this method
 */
public void doBatchUpdate(BatchUpdateDataDTO batchUpdateData) {
    String targetTable = batchUpdateData.getTableName();
    Authentication callerAuth = SecurityContextHolder.getContext().getAuthentication();
    UserInfo caller = new UserInfo(callerAuth);
    String newTaskId = UUID.randomUUID().toString().replace("-", "");

    // Persist the task record before the request is published.
    TaskDAO taskRecord = new TaskDAO(
            newTaskId,
            TaskTypeEnum.BATCH_MODIFY_PT,
            targetTable,
            TaskStatusEnum.RUNNING,
            caller.getUserId());
    TaskRepository.insertTask(taskRecord);

    batchUpdateData.setTaskId(newTaskId);
    batchUpdateData.setAuthentication(callerAuth);

    // Asynchronous task: handled on a different thread.
    disruptorTaskService.sendNotify(EventTypeEnum.BATCH_UPDATE_TABLE, batchUpdateData);
}
异步处理逻辑:
(原文此处为异步处理逻辑的截图)
缺点:需要手动设置,当需要传递的ThreadLocal变量比较多时,将是灾难
解决方式2:使用InheritableThreadLocal,原理其实和方式1类似
Thread 初始化逻辑
/**
 * Initializes this thread object using the thread that is executing this method as parent.
 *
 * <p>Thread group (when {@code g} is null), daemon flag, priority and context class loader
 * are taken from the parent. When {@code inheritThreadLocals} is true and the parent has a
 * non-null {@code inheritableThreadLocals} map, a new map is built from it via
 * {@code ThreadLocal.createInheritedMap} and assigned to this thread.
 *
 * @param g                   the thread group, or null to resolve it from the security
 *                            manager or the parent
 * @param target              the object stored in {@code this.target}
 * @param name                the thread name
 * @param stackSize           the value stored in {@code this.stackSize}
 * @param acc                 the access control context to store, or null to use
 *                            {@code AccessController.getContext()}
 * @param inheritThreadLocals whether to build {@code inheritableThreadLocals} from the parent
 * @throws NullPointerException if {@code name} is null
 */
private void init(ThreadGroup g, Runnable target, String name,
long stackSize, AccessControlContext acc,
boolean inheritThreadLocals) {
if (name == null) {
throw new NullPointerException("name cannot be null");
}
this.name = name;
// The parent is the thread running this initialization code.
Thread parent = currentThread();
SecurityManager security = System.getSecurityManager();
if (g == null) {
/* Determine if it's an applet or not */
/* If there is a security manager, ask the security manager
what to do. */
if (security != null) {
g = security.getThreadGroup();
}
/* If the security doesn't have a strong opinion of the matter
use the parent thread group. */
if (g == null) {
g = parent.getThreadGroup();
}
}
/* checkAccess regardless of whether or not threadgroup is
explicitly passed in. */
g.checkAccess();
/*
* Do we have the required permissions?
*/
if (security != null) {
if (isCCLOverridden(getClass())) {
security.checkPermission(SUBCLASS_IMPLEMENTATION_PERMISSION);
}
}
g.addUnstarted();
this.group = g;
// Daemon status and priority are copied from the parent.
this.daemon = parent.isDaemon();
this.priority = parent.getPriority();
if (security == null || isCCLOverridden(parent.getClass()))
this.contextClassLoader = parent.getContextClassLoader();
else
this.contextClassLoader = parent.contextClassLoader;
this.inheritedAccessControlContext =
acc != null ? acc : AccessController.getContext();
this.target = target;
setPriority(priority);
// Inheritable thread-locals are derived from the parent's map here, during initialization;
// nothing else in this method touches them.
if (inheritThreadLocals && parent.inheritableThreadLocals != null)
this.inheritableThreadLocals =
ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
/* Stash the specified stack size in case the VM cares */
this.stackSize = stackSize;
/* Set thread ID */
tid = nextThreadID();
}
在线程中开启新的线程时,新的线程为子线程,创建它的线程为父线程。子线程初始化时,会将父线程的inheritableThreadLocals数据拷贝过来,这样子线程就有了父线程的上下文信息了,和解决方式1有异曲同工之妙,但是拷贝由JDK自动完成,避免了自己手动设置造成的各种问题。需要注意拷贝只发生在线程创建时,线程池中被复用的线程不会再次继承。spring-security也提供了这种结构:
class InheritableThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
private static final ThreadLocal<SecurityContext> contextHolder = new InheritableThreadLocal();
InheritableThreadLocalSecurityContextHolderStrategy() {
}
public void clearContext() {
contextHolder.remove();
}
public SecurityContext getContext() {
SecurityContext ctx = (SecurityContext)contextHolder.get();
if (ctx == null) {
ctx = this.createEmptyContext();
contextHolder.set(ctx);
}
return ctx;
}