Test

Junit单元测试不支持多线程测试--原因分析和问题解决

2020-06-30  本文已影响0人  Zal哥哥

问题现象

在测试的代码中,如果是想要测试多线程的场景,将会出现失败,得不到任何测试结果。

作者在测试Spring Cloud Gateway时,用到了RedisRateLimiter,即服务网关限流的功能,用了令牌桶的方式来限流,一秒钟只允许10个请求通过。

写完代码后就打算测试一下,当然使用Jmeter来测试也是可以的,但是作者想通过测试代码的方式来测试,所以就遇到单元测试不支持多线程的问题。

先看代码:

package com.zz.cloud.gateway.spring.test;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.boot.web.server.LocalServerPort;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

import java.util.concurrent.CountDownLatch;

import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class GatewayRateLimiterTest {

    @LocalServerPort
    private int port;

    @Autowired
    private TestRestTemplate restTemplate;

    private int threadCnt = 1000;
    private CountDownLatch latch = new CountDownLatch(threadCnt);
    private int okCnt = 0;
    private int failCnt = 0;
    private int expCnt = 0;
    private int returnCnt = 0;

    class Runner implements Runnable {
        @Override
        public void run() {
            System.out.println("【当前线程ID】:" + Thread.currentThread().getId());

            try {
                ResponseEntity responseEntity = restTemplate.getForEntity("http://localhost:" + port + "/hello/tst/port?token=111", String.class);
                if(responseEntity.getStatusCode().equals(HttpStatus.OK)) {
                    okCnt ++;
                    System.out.println(Thread.currentThread().getId() + ": " + responseEntity.getBody().toString());
                } else {
                    failCnt ++;
                }
            } catch (Exception e) {
                expCnt ++;
            } finally {
                returnCnt ++;
            }
        }
    }

    @Test
    public void testRateLimiter() throws Exception {
        for (int i = 0; i < threadCnt; i++) {
            try {
                (new Thread(new Runner(), "JUNIT多线程测试")).start();
            } catch (OutOfMemoryError error) {
                System.out.println("OutOfMemoryError: " + i);
                latch.countDown();  // 执行完毕,计数器减1
            }
        }

        System.out.println("ok cnt: " + okCnt);
        System.out.println("fail cnt: " + failCnt);
        System.out.println("exception cnt: " + expCnt);
        System.out.println("return cnt: " + returnCnt);

        assertThat(okCnt).isBetween(10, 15);
    }
}

测试运行后,输出结果:

【当前线程ID】:27
【当前线程ID】:28
  ... //省略
【当前线程ID】:1028
【当前线程ID】:1029
【当前线程ID】:1030
【当前线程ID】:1031
ok cnt: 0
fail cnt: 0
exception cnt: 0
return cnt: 0

java.lang.AssertionError: 
Expecting:
 <0>
to be between:
 [10, 15]

原因分析

TestRunner源码片断:

public class TestRunner extends BaseTestRunner {
    private ResultPrinter fPrinter;

    public static final int SUCCESS_EXIT = 0;
    public static final int FAILURE_EXIT = 1;
    public static final int EXCEPTION_EXIT = 2;

    ...

    public static void main(String args[]) {
        TestRunner aTestRunner = new TestRunner();
        try {
            TestResult r = aTestRunner.start(args);
            if (!r.wasSuccessful()) {
                System.exit(FAILURE_EXIT);
            }
            System.exit(SUCCESS_EXIT);
        } catch (Exception e) {
            System.err.println(e.getMessage());
            System.exit(EXCEPTION_EXIT);
        }
    }

    ...
}

TestResult类源码片断

public class TestResult {
    protected List<TestFailure> fFailures;
    protected List<TestFailure> fErrors;
    protected List<TestListener> fListeners;
    protected int fRunTests;
    private boolean fStop;

    public TestResult() {
        fFailures = new ArrayList<TestFailure>();
        fErrors = new ArrayList<TestFailure>();
        fListeners = new ArrayList<TestListener>();
        fRunTests = 0;
        fStop = false;
    }

    ...

    /**
     * Returns whether the entire test was successful or not.
     */
    public synchronized boolean wasSuccessful() {
        return failureCount() == 0 && errorCount() == 0;
    }
}

在这里我们明显可以看到:
当aTestRunner调用start方法后不会去等待子线程执行完毕在关闭主线程,而是直接调用TestResult.wasSuccessful()方法,然后执行System.exit

            TestResult r = aTestRunner.start(args);
            if (!r.wasSuccessful()) {
                System.exit(FAILURE_EXIT);
            }
            System.exit(SUCCESS_EXIT);

即结束当前运行的jvm虚拟机,所以使用junit测试多线程就会等不到结果就结束执行了,最后达不到测试应有的结果;

问题解决

解决办法就是让主线程等待子线程全部运行结束后再结束,测试就会正常

解决方式1

Thread.sleep();
让主线程等待一定的时间,这个办法不是很好,因为你不知道所有子线程结束的时间,只能猜,所以不准确,效果可能不是很好。

解决方式2

线程计数器 CountDownLatch

直接给出代码:

package com.zz.cloud.gateway.spring.test;

import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.boot.web.server.LocalServerPort;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

import java.util.concurrent.CountDownLatch;

import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
public class GatewayRateLimiterTest {

    @LocalServerPort
    private int port;

    @Autowired
    private TestRestTemplate restTemplate;

    private int threadCnt = 1000;
    private CountDownLatch latch = new CountDownLatch(threadCnt);
    private int okCnt = 0;
    private int failCnt = 0;
    private int expCnt = 0;
    private int returnCnt = 0;

    protected synchronized void incOkCnt() {
        this.okCnt ++;
    }

    protected synchronized void incFailCnt() {
        this.failCnt ++;
    }

    protected synchronized void incExpCnt() {
        this.expCnt ++;
    }

    protected synchronized void incReturnCnt() {
        this.returnCnt ++;
    }

    class Runner implements Runnable {
        @Override
        public void run() {
            System.out.println("【当前线程ID】:" + Thread.currentThread().getId());

            try {
                ResponseEntity responseEntity = restTemplate.getForEntity("http://localhost:" + port + "/hello/tst/port?token=111", String.class);
                if(responseEntity.getStatusCode().equals(HttpStatus.OK)) {
                    incOkCnt();
                    System.out.println(Thread.currentThread().getId() + ": " + responseEntity.getBody().toString());
                } else {
                    incFailCnt();
                }
            } catch (Exception e) {
                incExpCnt();
            } finally {
                incReturnCnt();
            }

            latch.countDown(); // 执行完毕,计数器减1
        }
    }

    @Test
    public void testRateLimiter() throws Exception {
        for (int i = 0; i < threadCnt; i++) {
            try {
                (new Thread(new Runner(), "JUNIT多线程测试")).start();
            } catch (OutOfMemoryError error) {
                System.out.println("OutOfMemoryError: " + i);
                latch.countDown();  // 执行完毕,计数器减1
            }
        }

        try {
            latch.await(); // 主线程等待

            System.out.println("ok cnt: " + okCnt);
            System.out.println("fail cnt: " + failCnt);
            System.out.println("exception cnt: " + expCnt);
            System.out.println("return cnt: " + returnCnt);

            assertThat(okCnt).isBetween(10, 15);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

如上,即可完美解决测试时不支持多线程的问题。

附:

如果有细心的读者,在执行上面的测试代码时,会发现上面okCnt、failCnt、expCnt、returnCnt的变量的自增,使用了synchronized修饰的方法,这是为了修改变量时的线程安全。不然可能会得不到正确的统计数量。

如果你的线程数量不是1000,而是更大的值,比如10000,那就可能会产生OOM的错误,这个错误的产生是因为本地内存可能不够创建native线程,具体的分析可以参考:
由多线程内存溢出产生的实战分析

上一篇 下一篇

猜你喜欢

热点阅读