之前简单介绍过 ThreadLocal ,但是其中有个问题就是当一个请求中使用到线程池时,无法将主线程中ThreadLocal中的值传递进去,这次我们就看下怎么解决这个问题
比较直接的的方法就是包装一下Runnable或Callable,在创建的时候将主线程中ThreadLocal对应内容传递保存进去,之后执行的时候再取出来重新赋值到对应ThreadLocal中,使用之后再清理掉即可,大致样子如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 public class SimpleThreadLocalTest { private Executor executor = Executors.newSingleThreadExecutor(); public static ThreadLocal<String> USER_NAME_THREAD_LOCAL = new ThreadLocal <>(); @Test public void test () { USER_NAME_THREAD_LOCAL.set("zheng" ); executor.execute(new ThreadLocalRunnable ()); } static class ThreadLocalRunnable implements Runnable { private String userName; public ThreadLocalRunnable () { this .userName = USER_NAME_THREAD_LOCAL.get(); } @Override public void run () { try { USER_NAME_THREAD_LOCAL.set(userName); System.out.println("userName: " + USER_NAME_THREAD_LOCAL.get()); } finally { USER_NAME_THREAD_LOCAL.remove(); } } } }
这里很明显可以看出来,自定义的Runnable实现与系统中定义的ThreadLocal进行了强耦合,当有更多的ThreadLocal时会使代码很难维护,比较幸运的是,这种工具已经有了比较好的开源实现,这里就介绍下transmittable-thread-local
使用 先来看下它的使用方法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 public class TtlMain { ExecutorService executorService = Executors.newSingleThreadExecutor(); TransmittableThreadLocal<String> nameThreadLocal = new TransmittableThreadLocal <>(); @Test public void test () { nameThreadLocal.set("zheng" ); Runnable r = TtlRunnable.get(() -> { final String s = nameThreadLocal.get(); System.out.println(s); }); executorService.execute(r); } }
如果觉得每次使用TtlRunnable进行包装比较麻烦,可以使用它提供的线程池进行包装
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 public class TtlMain { ExecutorService executorService = TtlExecutors.getTtlExecutorService(Executors.newSingleThreadExecutor()); TransmittableThreadLocal<String> nameThreadLocal = new TransmittableThreadLocal <>(); @Test public void test () { nameThreadLocal.set("zheng" ); executorService.execute(() -> { final String s = nameThreadLocal.get(); System.out.println(s); }); } }
原理 原理部分相对啰嗦些,着急知道结果的可以直接看小结部分
TransmittableThreadLocal 这里面有个很重要的类:TransmittableThreadLocal,先来看下它的get, set方法部分源码实现
这里可以看到,它继承了InheritableThreadLocal,在 get 和 set 时直接调用父类 InheritableThreadLocal 的方法,就是在set时多了一步,会将TransmittableThreadLocal的实例统一保存起来,这个后面在进行跨线程赋值传递的时候会用到,不需要到处去找都有哪些TransmittableThreadLocal实例的数据要进行复制
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 public class TransmittableThreadLocal <T> extends InheritableThreadLocal <T> implements TtlCopier <T> { @Override public final T get () { T value = super .get(); if (disableIgnoreNullValueSemantics || null != value) addThisToHolder(); return value; } @Override public final void set (T value) { if (!disableIgnoreNullValueSemantics && null == value) { remove(); } else { super .set(value); addThisToHolder(); } } private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder = new InheritableThreadLocal <WeakHashMap<TransmittableThreadLocal<Object>, ?>>() { @Override protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() { return new WeakHashMap <TransmittableThreadLocal<Object>, Object>(); } @Override protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) { return new WeakHashMap <TransmittableThreadLocal<Object>, Object>(parentValue); } }; private void addThisToHolder () { if (!holder.get().containsKey(this )) { holder.get().put((TransmittableThreadLocal<Object>) this , null ); } } }
同时需要注意的是,如果我们有深拷贝的需求,可以实现一个TransmittableThreadLocal子类,重写它的copy方法即可
1 2 3 4 5 6 7 TransmittableThreadLocal<String> nameThreadLocal = new TransmittableThreadLocal <String>() { @Override public String copy (String parentValue) { return super .copy(parentValue); } };
Transmitter 除此之外,还有一个很重要的类:Transmitter,它是TransmittableThreadLocal的一个内部类,其中的方法都是静态方法,主要用来在线程切换时进行数据的快照保存(capture)、重放(replay)和恢复(restore),在看源码之前先看一下使用的例子
利用Transmitter将主线程的数据快照进行记录
在子线程/线程池中执行时,将记录的快照数据进行重新设置到当前线程,并将当前子线程的数据进行备份
执行完毕后将备份的数据恢复到当前线程数据中
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 public void test1 () { Object captured = Transmitter.capture(); Callable<String> callable = () -> { Object backup = Transmitter.replay(captured); try { System.out.println("Hello" ); return "World" ; } finally { Transmitter.restore(backup); } }; executorService.submit(callable); }
下面依次看下这几个步骤的实现
Transmitter.capture 这个方法本身比较简单,它是在主线程中执行的,主要就是将之前记录到的所有TransmittableThreadLocal实例数据转成对应map进行返回
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 public static class Transmitter { public static Object capture () { return new Snapshot (captureTtlValues(), captureThreadLocalValues()); } private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues () { HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap <TransmittableThreadLocal<Object>, Object>(); for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) { ttl2Value.put(threadLocal, threadLocal.copyValue()); } return ttl2Value; } private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues () { final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap <ThreadLocal<Object>, Object>(); for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) { final ThreadLocal<Object> threadLocal = entry.getKey(); final TtlCopier<Object> copier = entry.getValue(); threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get())); } return threadLocal2Value; } } private static class Snapshot { final HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value; final HashMap<ThreadLocal<Object>, Object> threadLocal2Value; private Snapshot (HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, HashMap<ThreadLocal<Object>, Object> threadLocal2Value) { this .ttl2Value = ttl2Value; this .threadLocal2Value = threadLocal2Value; } }
Transmitter.replay 这个方法是在子线程/线程池中执行的,用于将快照中的数据设置到当前线程中,并将当前线程中的数据进行备份返回
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 public static Object replay (@NonNull Object captured) { final Snapshot capturedSnapshot = (Snapshot) captured; return new Snapshot (replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value)); } @NonNull private static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues (@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) { HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap <TransmittableThreadLocal<Object>, Object>(); for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) { TransmittableThreadLocal<Object> threadLocal = iterator.next(); backup.put(threadLocal, threadLocal.get()); if (!captured.containsKey(threadLocal)) { iterator.remove(); threadLocal.superRemove(); } } setTtlValuesTo(captured); doExecuteCallback(true ); return backup; } private static void setTtlValuesTo (HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) { for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) { TransmittableThreadLocal<Object> threadLocal = entry.getKey(); threadLocal.set(entry.getValue()); } }
Transmitter.restore 这个方法是在子线程/线程池中执行的,用于在业务逻辑处理完成后,将子线程之前的线程相关数据进行恢复,也即是进行使用后的清理恢复工作
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 public static void restore (@NonNull Object backup) { final Snapshot backupSnapshot = (Snapshot) backup; restoreTtlValues(backupSnapshot.ttl2Value); restoreThreadLocalValues(backupSnapshot.threadLocal2Value); } private static void restoreTtlValues (HashMap<TransmittableThreadLocal<Object>, Object> backup) { doExecuteCallback(false ); for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) { TransmittableThreadLocal<Object> threadLocal = iterator.next(); if (!backup.containsKey(threadLocal)) { iterator.remove(); threadLocal.superRemove(); } } setTtlValuesTo(backup); } private static void setTtlValuesTo (@NonNull HashMap<TransmittableThreadLocal<Object>, Object> ttlValues) { for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) { TransmittableThreadLocal<Object> threadLocal = entry.getKey(); threadLocal.set(entry.getValue()); } }
有了上面的这些类和方法进行支撑,TtlRunnable或者TtlExecutors等进行使用时就比较容易了,我们简单看一下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 public final class TtlRunnable implements Runnable , TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments { private final AtomicReference<Object> capturedRef; private final Runnable runnable; private final boolean releaseTtlValueReferenceAfterRun; private TtlRunnable (Runnable runnable, boolean releaseTtlValueReferenceAfterRun) { this .capturedRef = new AtomicReference <Object>(capture()); this .runnable = runnable; this .releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun; } @Override public void run () { final Object captured = capturedRef.get(); if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null )) { throw new IllegalStateException ("TTL value reference is released after run!" ); } final Object backup = replay(captured); try { runnable.run(); } finally { restore(backup); } } }
其中我们还可以发现一个点就是Transmitter除了处理TransmittableThreadLocal中的holder,还用同样的方法处理使用它的一个静态成员变量threadLocalHolder
1 2 private static volatile WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>> threadLocalHolder = new WeakHashMap <ThreadLocal<Object>, TtlCopier<Object>>();
这个threadLocalHolder的作用是对于在项目中使用了ThreadLocal,但是却无法替换为TransmittableThreadLocal的情况,可以使用Transmitter提供的注册方法,将项目中的threadLocal注册到它的threadLocalHolder中,后面进行capture等操作时holder和threadLocalHolder都会进行处理使用
1 2 3 4 5 6 Transmitter.registerThreadLocalWithShadowCopier(threadLocal); Transmitter.registerThreadLocal(threadLocal, copyLambda); Transmitter.unregisterThreadLocal(threadLocal);
对应的部分源码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 public static <T> boolean registerThreadLocalWithShadowCopier (@NonNull ThreadLocal<T> threadLocal) { return registerThreadLocal(threadLocal, (TtlCopier<T>) shadowCopier, false ); } public static <T> boolean registerThreadLocal (@NonNull ThreadLocal<T> threadLocal, @NonNull TtlCopier<T> copier, boolean force) { if (threadLocal instanceof TransmittableThreadLocal) { logger.warning("register a TransmittableThreadLocal instance, this is unnecessary!" ); return true ; } synchronized (threadLocalHolderUpdateLock) { if (!force && threadLocalHolder.containsKey(threadLocal)) return false ; WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>> newHolder = new WeakHashMap <ThreadLocal<Object>, TtlCopier<Object>>(threadLocalHolder); newHolder.put((ThreadLocal<Object>) threadLocal, (TtlCopier<Object>) copier); threadLocalHolder = newHolder; return true ; } }
小结 通过上面的分析,可以发现核心类其实就是两个:TransmittableThreadLocal
和 Transmitter
在使用 TransmittableThreadLocal 时,它在将值保存到父类 InheritableThreadLocal 中的同时,会将当前的 TransmittableThreadLocal 实际进行存储,这样使用完成后,它自己就会维护一份所有用到的TransmittableThreadLocal 实例,不管它是用户信息的,还是其他信息的实例
有了上面维护的信息,就可以借助Transmitter来对其中的数据进行操作,一般操作步骤如下
主线程:调用Transmitter.capture,将当前主线程中的所有TransmittableThreadLocal和值进行快照保存(Map结构,结果要作为value进行存储,否则其他线程取不到TransmittableThreadLocal的value值)
子线程:调用Transmitter.replay,用于将之前保存的所有TransmittableThreadLocal实例及其值重新设置一下(需要借助之前保存的map结构,因为TransmittableThreadLocal中的数据是线程隔离的),并将当前线程的所有TransmittableThreadLocal实例进行备份返回
子线程:业务代码执行完毕之后调用Transmitter.restore,用于将之前备份的数据进行恢复,原理同replay方法