Junit单元测试不支持多线程测试--原因分析和问题解决
问题现象
在测试的代码中,如果是想要测试多线程的场景,将会出现失败,得不到任何测试结果。
作者在测试Spring Cloud Gateway时,用到了RedisRateLimiter,即服务网关限流的功能,用了令牌桶的方式来限流,一秒钟只允许10个请求通过。
写完代码后就打算测试一下,当然使用Jmeter来测试也是可以的,但是作者想通过测试代码的方式来测试,所以就遇到单元测试不支持多线程的问题。
先看代码:
package com.zz.cloud.gateway.spring.test;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.boot.web.server.LocalServerPort;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import java.util.concurrent.CountDownLatch;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class GatewayRateLimiterTest {
@LocalServerPort
private int port;
@Autowired
private TestRestTemplate restTemplate;
private int threadCnt = 1000;
private CountDownLatch latch = new CountDownLatch(threadCnt);
private int okCnt = 0;
private int failCnt = 0;
private int expCnt = 0;
private int returnCnt = 0;
class Runner implements Runnable {
@Override
public void run() {
System.out.println("【当前线程ID】:" + Thread.currentThread().getId());
try {
ResponseEntity responseEntity = restTemplate.getForEntity("http://localhost:" + port + "/hello/tst/port?token=111", String.class);
if(responseEntity.getStatusCode().equals(HttpStatus.OK)) {
okCnt ++;
System.out.println(Thread.currentThread().getId() + ": " + responseEntity.getBody().toString());
} else {
failCnt ++;
}
} catch (Exception e) {
expCnt ++;
} finally {
returnCnt ++;
}
}
}
@Test
public void testRateLimiter() throws Exception {
for (int i = 0; i < threadCnt; i++) {
try {
(new Thread(new Runner(), "JUNIT多线程测试")).start();
} catch (OutOfMemoryError error) {
System.out.println("OutOfMemoryError: " + i);
latch.countDown(); // 执行完毕,计数器减1
}
}
System.out.println("ok cnt: " + okCnt);
System.out.println("fail cnt: " + failCnt);
System.out.println("exception cnt: " + expCnt);
System.out.println("return cnt: " + returnCnt);
assertThat(okCnt).isBetween(10, 15);
}
}
测试运行后,输出结果:
【当前线程ID】:27
【当前线程ID】:28
... //省略
【当前线程ID】:1028
【当前线程ID】:1029
【当前线程ID】:1030
【当前线程ID】:1031
ok cnt: 0
fail cnt: 0
exception cnt: 0
return cnt: 0
java.lang.AssertionError:
Expecting:
<0>
to be between:
[10, 15]
原因分析
TestRunner源码片断:
public class TestRunner extends BaseTestRunner {
private ResultPrinter fPrinter;
public static final int SUCCESS_EXIT = 0;
public static final int FAILURE_EXIT = 1;
public static final int EXCEPTION_EXIT = 2;
...
public static void main(String args[]) {
TestRunner aTestRunner = new TestRunner();
try {
TestResult r = aTestRunner.start(args);
if (!r.wasSuccessful()) {
System.exit(FAILURE_EXIT);
}
System.exit(SUCCESS_EXIT);
} catch (Exception e) {
System.err.println(e.getMessage());
System.exit(EXCEPTION_EXIT);
}
}
...
}
TestResult类源码片断
public class TestResult {
protected List<TestFailure> fFailures;
protected List<TestFailure> fErrors;
protected List<TestListener> fListeners;
protected int fRunTests;
private boolean fStop;
public TestResult() {
fFailures = new ArrayList<TestFailure>();
fErrors = new ArrayList<TestFailure>();
fListeners = new ArrayList<TestListener>();
fRunTests = 0;
fStop = false;
}
...
/**
* Returns whether the entire test was successful or not.
*/
public synchronized boolean wasSuccessful() {
return failureCount() == 0 && errorCount() == 0;
}
}
在这里我们明显可以看到:
当aTestRunner调用start方法后不会去等待子线程执行完毕在关闭主线程,而是直接调用TestResult.wasSuccessful()方法,然后执行System.exit
TestResult r = aTestRunner.start(args);
if (!r.wasSuccessful()) {
System.exit(FAILURE_EXIT);
}
System.exit(SUCCESS_EXIT);
即结束当前运行的jvm虚拟机,所以使用junit测试多线程就会等不到结果就结束执行了,最后达不到测试应有的结果;
问题解决
解决办法就是让主线程等待子线程全部运行结束后再结束,测试就会正常
解决方式1
Thread.sleep();
让主线程等待一定的时间,这个办法不是很好,因为你不知道所有子线程结束的时间,只能猜,所以不准确,效果可能不是很好。
解决方式2
线程计数器 CountDownLatch
直接给出代码:
package com.zz.cloud.gateway.spring.test;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.boot.web.server.LocalServerPort;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import java.util.concurrent.CountDownLatch;
import static org.assertj.core.api.Assertions.assertThat;
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class GatewayRateLimiterTest {
@LocalServerPort
private int port;
@Autowired
private TestRestTemplate restTemplate;
private int threadCnt = 1000;
private CountDownLatch latch = new CountDownLatch(threadCnt);
private int okCnt = 0;
private int failCnt = 0;
private int expCnt = 0;
private int returnCnt = 0;
protected synchronized void incOkCnt() {
this.okCnt ++;
}
protected synchronized void incFailCnt() {
this.failCnt ++;
}
protected synchronized void incExpCnt() {
this.expCnt ++;
}
protected synchronized void incReturnCnt() {
this.returnCnt ++;
}
class Runner implements Runnable {
@Override
public void run() {
System.out.println("【当前线程ID】:" + Thread.currentThread().getId());
try {
ResponseEntity responseEntity = restTemplate.getForEntity("http://localhost:" + port + "/hello/tst/port?token=111", String.class);
if(responseEntity.getStatusCode().equals(HttpStatus.OK)) {
incOkCnt();
System.out.println(Thread.currentThread().getId() + ": " + responseEntity.getBody().toString());
} else {
incFailCnt();
}
} catch (Exception e) {
incExpCnt();
} finally {
incReturnCnt();
}
latch.countDown(); // 执行完毕,计数器减1
}
}
@Test
public void testRateLimiter() throws Exception {
for (int i = 0; i < threadCnt; i++) {
try {
(new Thread(new Runner(), "JUNIT多线程测试")).start();
} catch (OutOfMemoryError error) {
System.out.println("OutOfMemoryError: " + i);
latch.countDown(); // 执行完毕,计数器减1
}
}
try {
latch.await(); // 主线程等待
System.out.println("ok cnt: " + okCnt);
System.out.println("fail cnt: " + failCnt);
System.out.println("exception cnt: " + expCnt);
System.out.println("return cnt: " + returnCnt);
assertThat(okCnt).isBetween(10, 15);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
如上,即可完美解决测试时不支持多线程的问题。
附:
如果有细心的读者,在执行上面的测试代码时,会发现上面okCnt、failCnt、expCnt、returnCnt的变量的自增,使用了synchronized修饰的方法,这是为了修改变量时的线程安全。不然可能会得不到正确的统计数量。
如果你的线程数量不是1000,而是更大的值,比如10000,那就可能会产生OOM的错误,这个错误的产生是因为本地内存可能不够创建native线程,具体的分析可以参考:
由多线程内存溢出产生的实战分析