1
2
3
4
5 package mockit.integration.junit5;
6
7 import static mockit.internal.util.StackTrace.filterStackTrace;
8
9 import edu.umd.cs.findbugs.annotations.NonNull;
10 import edu.umd.cs.findbugs.annotations.Nullable;
11
12 import java.lang.reflect.Method;
13 import java.util.Arrays;
14 import java.util.stream.Collectors;
15
16 import mockit.Capturing;
17 import mockit.Injectable;
18 import mockit.Mocked;
19 import mockit.Tested;
20 import mockit.integration.TestRunnerDecorator;
21 import mockit.internal.expectations.RecordAndReplayExecution;
22 import mockit.internal.state.SavePoint;
23 import mockit.internal.state.TestRun;
24 import mockit.internal.util.Utilities;
25
26 import org.junit.jupiter.api.BeforeAll;
27 import org.junit.jupiter.api.BeforeEach;
28 import org.junit.jupiter.api.Nested;
29 import org.junit.jupiter.api.extension.AfterAllCallback;
30 import org.junit.jupiter.api.extension.AfterEachCallback;
31 import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
32 import org.junit.jupiter.api.extension.BeforeAllCallback;
33 import org.junit.jupiter.api.extension.BeforeEachCallback;
34 import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
35 import org.junit.jupiter.api.extension.ExtensionContext;
36 import org.junit.jupiter.api.extension.ParameterContext;
37 import org.junit.jupiter.api.extension.ParameterResolver;
38 import org.junit.jupiter.api.extension.TestExecutionExceptionHandler;
39 import org.junit.jupiter.api.extension.TestInstancePostProcessor;
40
41 @SuppressWarnings("Since15")
42 public final class JMockitExtension extends TestRunnerDecorator implements BeforeAllCallback, AfterAllCallback,
43 TestInstancePostProcessor, BeforeEachCallback, AfterEachCallback, BeforeTestExecutionCallback,
44 AfterTestExecutionCallback, ParameterResolver, TestExecutionExceptionHandler {
45 @Nullable
46 private SavePoint savePointForTestClass;
47 @Nullable
48 private SavePoint savePointForTest;
49 @Nullable
50 private SavePoint savePointForTestMethod;
51 @Nullable
52 private Throwable thrownByTest;
53 private Object[] parameterValues;
54 private ParamValueInitContext initContext = new ParamValueInitContext(null, null, null,
55 "No callbacks have been processed, preventing parameter population");
56
57 @Override
58 public void beforeAll(@NonNull ExtensionContext context) {
59 if (isRegularTestClass(context)) {
60 @Nullable
61 Class<?> testClass = context.getTestClass().orElse(null);
62 savePointForTestClass = new SavePoint();
63 TestRun.setCurrentTestClass(testClass);
64
65 if (testClass == null) {
66 initContext = new ParamValueInitContext(null, null, null,
67 "@BeforeAll setup failed to acquire 'Class' of test");
68 return;
69 }
70
71
72 Object testInstance = context.getTestInstance().orElse(null);
73 Method beforeAllMethod = Utilities.getAnnotatedDeclaredMethod(testClass, BeforeAll.class);
74 if (testInstance == null) {
75 initContext = new ParamValueInitContext(null, testClass, beforeAllMethod,
76 "@BeforeAll setup failed to acquire instance of test class");
77 return;
78 }
79
80 if (beforeAllMethod != null) {
81 initContext = new ParamValueInitContext(testInstance, testClass, beforeAllMethod, null);
82 parameterValues = createInstancesForAnnotatedParameters(testInstance, beforeAllMethod, null);
83 }
84 }
85 }
86
87 private static boolean isRegularTestClass(@NonNull ExtensionContext context) {
88 Class<?> testClass = context.getTestClass().orElse(null);
89 return testClass != null && !testClass.isAnnotationPresent(Nested.class);
90 }
91
92 @Override
93 public void postProcessTestInstance(@NonNull Object testInstance, @NonNull ExtensionContext context) {
94 if (isRegularTestClass(context)) {
95 TestRun.enterNoMockingZone();
96
97 try {
98 handleMockFieldsForWholeTestClass(testInstance);
99 } finally {
100 TestRun.exitNoMockingZone();
101 }
102
103 TestRun.setRunningIndividualTest(testInstance);
104 }
105 }
106
107 @Override
108 public void beforeEach(@NonNull ExtensionContext context) {
109 Object testInstance = context.getTestInstance().orElse(null);
110 Class<?> testClass = context.getTestClass().orElse(null);
111 if (testInstance == null) {
112 initContext = new ParamValueInitContext(null, null, null,
113 "@BeforeEach setup failed to acquire instance of test class");
114 return;
115 }
116
117 TestRun.prepareForNextTest();
118 TestRun.enterNoMockingZone();
119
120 try {
121 savePointForTest = new SavePoint();
122 createInstancesForTestedFieldsBeforeSetup(testInstance);
123
124 if (testClass == null) {
125 initContext = new ParamValueInitContext(null, null, null,
126 "@BeforeEach setup failed to acquire Class<?> of test");
127 return;
128 }
129
130 Method beforeEachMethod = Utilities.getAnnotatedDeclaredMethod(testClass, BeforeEach.class);
131 if (beforeEachMethod != null) {
132 initContext = new ParamValueInitContext(testInstance, testClass, beforeEachMethod, null);
133 parameterValues = createInstancesForAnnotatedParameters(testInstance, beforeEachMethod, null);
134 }
135 } finally {
136 TestRun.exitNoMockingZone();
137 }
138 }
139
140 @Override
141 public void beforeTestExecution(@NonNull ExtensionContext context) {
142 Class<?> testClass = context.getTestClass().orElse(null);
143 Method testMethod = context.getTestMethod().orElse(null);
144 Object testInstance = context.getTestInstance().orElse(null);
145
146 if (testMethod == null || testInstance == null) {
147 initContext = new ParamValueInitContext(testInstance, testClass, testMethod,
148 "@Test failed to acquire instance of test class, or target method");
149 return;
150 }
151
152 TestRun.enterNoMockingZone();
153
154 try {
155 savePointForTestMethod = new SavePoint();
156 createInstancesForTestedFieldsFromBaseClasses(testInstance);
157 initContext = new ParamValueInitContext(testInstance, testClass, testMethod, null);
158 parameterValues = createInstancesForAnnotatedParameters(testInstance, testMethod, null);
159 createInstancesForTestedFields(testInstance);
160 } finally {
161 TestRun.exitNoMockingZone();
162 }
163
164 TestRun.setRunningIndividualTest(testInstance);
165 }
166
167 @Override
168 public boolean supportsParameter(@NonNull ParameterContext parameterContext,
169 @NonNull ExtensionContext extensionContext) {
170 return parameterContext.isAnnotated(Tested.class) || parameterContext.isAnnotated(Mocked.class)
171 || parameterContext.isAnnotated(Injectable.class) || parameterContext.isAnnotated(Capturing.class);
172 }
173
174 @Override
175 public Object resolveParameter(@NonNull ParameterContext parameterContext,
176 @NonNull ExtensionContext extensionContext) {
177 int parameterIndex = parameterContext.getIndex();
178 if (parameterValues == null) {
179 String warning = initContext.warning;
180 StringBuilder exceptionMessage = new StringBuilder(
181 "JMockit failed to provide parameters to JUnit 5 ParameterResolver.");
182 if (warning != null) {
183 exceptionMessage.append("\nAdditional info: ").append(warning);
184 }
185 exceptionMessage.append("\n - Class: ").append(initContext.displayClass());
186 exceptionMessage.append("\n - Method: ").append(initContext.displayMethod());
187 throw new IllegalStateException(exceptionMessage.toString());
188 }
189 return parameterValues[parameterIndex];
190 }
191
192 @Override
193 public void handleTestExecutionException(@NonNull ExtensionContext context, @NonNull Throwable throwable)
194 throws Throwable {
195 thrownByTest = throwable;
196 throw throwable;
197 }
198
199 @Override
200 public void afterTestExecution(@NonNull ExtensionContext context) {
201 if (savePointForTestMethod != null) {
202 TestRun.enterNoMockingZone();
203
204 try {
205 savePointForTestMethod.rollback();
206 savePointForTestMethod = null;
207
208 if (thrownByTest != null) {
209 filterStackTrace(thrownByTest);
210 }
211
212 Error expectationsFailure = RecordAndReplayExecution.endCurrentReplayIfAny();
213 clearTestedObjectsIfAny();
214
215 if (expectationsFailure != null) {
216 filterStackTrace(expectationsFailure);
217 throw expectationsFailure;
218 }
219 } finally {
220 TestRun.finishCurrentTestExecution();
221 TestRun.exitNoMockingZone();
222 }
223 }
224 }
225
226 @Override
227 public void afterEach(@NonNull ExtensionContext context) {
228 if (savePointForTest != null) {
229 savePointForTest.rollback();
230 savePointForTest = null;
231 }
232 }
233
234 @Override
235 public void afterAll(@NonNull ExtensionContext context) {
236 if (savePointForTestClass != null && isRegularTestClass(context)) {
237 savePointForTestClass.rollback();
238 savePointForTestClass = null;
239
240 clearFieldTypeRedefinitions();
241 TestRun.setCurrentTestClass(null);
242 }
243 }
244
245 private static class ParamValueInitContext {
246 private final Object instance;
247 private final Class<?> clazz;
248 private final Method method;
249 private final String warning;
250
251 ParamValueInitContext(Object instance, Class<?> clazz, Method method, String warning) {
252 this.instance = instance;
253 this.clazz = clazz;
254 this.method = method;
255 this.warning = warning;
256 }
257
258 boolean isBeforeAllMethod() {
259 return method.getDeclaredAnnotation(BeforeAll.class) != null;
260 }
261
262 boolean isBeforeEachMethod() {
263 return method.getDeclaredAnnotation(BeforeEach.class) != null;
264 }
265
266 String displayClass() {
267 if (clazz == null) {
268 return "<no class reference>";
269 }
270 return clazz.getName();
271 }
272
273 String displayMethod() {
274 if (method == null) {
275 return "<no method reference>";
276 }
277 String methodPrefix = isBeforeAllMethod() ? "@BeforeAll " : isBeforeEachMethod() ? "@BeforeEach " : "";
278 String args = Arrays.stream(method.getParameterTypes()).map(Class::getName)
279 .collect(Collectors.joining(", "));
280 return methodPrefix + method.getName() + "(" + args + ")";
281 }
282
283 @Override
284 public String toString() {
285 return "ParamContext{" + "hasInstance=" + (instance == null ? "false" : "true") + ", class=" + clazz
286 + ", method=" + method + '}';
287 }
288 }
289 }