1
2
3
4
5
6 package mockit.integration.junit5;
7
8 import edu.umd.cs.findbugs.annotations.NonNull;
9 import edu.umd.cs.findbugs.annotations.Nullable;
10
11 import java.lang.reflect.Method;
12 import java.util.Arrays;
13 import java.util.stream.Collectors;
14
15 import mockit.Capturing;
16 import mockit.Injectable;
17 import mockit.Mocked;
18 import mockit.Tested;
19 import mockit.integration.TestRunnerDecorator;
20 import mockit.internal.expectations.RecordAndReplayExecution;
21 import mockit.internal.faking.FakeStates;
22 import mockit.internal.state.SavePoint;
23 import mockit.internal.state.TestRun;
24 import mockit.internal.util.StackTrace;
25 import mockit.internal.util.Utilities;
26
27 import org.junit.jupiter.api.BeforeAll;
28 import org.junit.jupiter.api.BeforeEach;
29 import org.junit.jupiter.api.Nested;
30 import org.junit.jupiter.api.extension.AfterAllCallback;
31 import org.junit.jupiter.api.extension.AfterEachCallback;
32 import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
33 import org.junit.jupiter.api.extension.BeforeAllCallback;
34 import org.junit.jupiter.api.extension.BeforeEachCallback;
35 import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
36 import org.junit.jupiter.api.extension.ExtensionContext;
37 import org.junit.jupiter.api.extension.ParameterContext;
38 import org.junit.jupiter.api.extension.ParameterResolver;
39 import org.junit.jupiter.api.extension.TestExecutionExceptionHandler;
40 import org.junit.jupiter.api.extension.TestInstancePostProcessor;
41 import org.opentest4j.TestAbortedException;
42
43 public final class JMockitExtension extends TestRunnerDecorator implements BeforeAllCallback, AfterAllCallback,
44 TestInstancePostProcessor, BeforeEachCallback, AfterEachCallback, BeforeTestExecutionCallback,
45 AfterTestExecutionCallback, ParameterResolver, TestExecutionExceptionHandler {
46 @Nullable
47 private SavePoint savePointForTestClass;
48 @Nullable
49 private SavePoint savePointForTest;
50 @Nullable
51 private SavePoint savePointForTestMethod;
52 @Nullable
53 private Throwable thrownByTest;
54 private Object[] parameterValues;
55 private ParamValueInitContext initContext = new ParamValueInitContext(null, null, null,
56 "No callbacks have been processed, preventing parameter population");
57
58 @Override
59 public void beforeAll(@NonNull ExtensionContext context) {
60 if (!isRegularTestClass(context)) {
61 return;
62 }
63
64 @Nullable
65 Class<?> testClass = context.getTestClass().orElse(null);
66 savePointForTestClass = new SavePoint();
67
68 if (testClass != null) {
69 updateTestClassState(null, testClass);
70 }
71
72 if (testClass == null) {
73 initContext = new ParamValueInitContext(null, null, null,
74 "@BeforeAll setup failed to acquire 'Class' of test");
75 return;
76 }
77
78
79 Object testInstance = context.getTestInstance().orElse(null);
80 Method beforeAllMethod = Utilities.getAnnotatedDeclaredMethod(testClass, BeforeAll.class);
81 if (testInstance == null) {
82 initContext = new ParamValueInitContext(null, testClass, beforeAllMethod,
83 "@BeforeAll setup failed to acquire instance of test class");
84 return;
85 }
86
87 if (beforeAllMethod != null) {
88 initContext = new ParamValueInitContext(testInstance, testClass, beforeAllMethod, null);
89 parameterValues = createInstancesForAnnotatedParameters(testInstance, beforeAllMethod, null);
90 }
91 }
92
93 private static boolean isRegularTestClass(@NonNull ExtensionContext context) {
94 Class<?> testClass = context.getTestClass().orElse(null);
95 return testClass != null && !testClass.isAnnotationPresent(Nested.class);
96 }
97
98 @Override
99 public void postProcessTestInstance(@NonNull Object testInstance, @NonNull ExtensionContext context) {
100 if (!isRegularTestClass(context)) {
101 return;
102 }
103
104 TestRun.enterNoMockingZone();
105
106 try {
107 handleMockFieldsForWholeTestClass(testInstance);
108 } finally {
109 TestRun.exitNoMockingZone();
110 }
111
112 TestRun.setRunningIndividualTest(testInstance);
113 }
114
115 @Override
116 public void beforeEach(@NonNull ExtensionContext context) {
117 Object testInstance = context.getTestInstance().orElse(null);
118 Class<?> testClass = context.getTestClass().orElse(null);
119 if (testInstance == null) {
120 initContext = new ParamValueInitContext(null, null, null,
121 "@BeforeEach setup failed to acquire instance of test class");
122 return;
123 }
124
125 TestRun.prepareForNextTest();
126 TestRun.enterNoMockingZone();
127
128 try {
129 savePointForTest = new SavePoint();
130 createInstancesForTestedFieldsBeforeSetup(testInstance);
131
132 if (testClass == null) {
133 initContext = new ParamValueInitContext(null, null, null,
134 "@BeforeEach setup failed to acquire Class<?> of test");
135 return;
136 }
137
138 Method beforeEachMethod = Utilities.getAnnotatedDeclaredMethod(testClass, BeforeEach.class);
139 if (beforeEachMethod != null) {
140 initContext = new ParamValueInitContext(testInstance, testClass, beforeEachMethod, null);
141 parameterValues = createInstancesForAnnotatedParameters(testInstance, beforeEachMethod, null);
142 }
143 } finally {
144 TestRun.exitNoMockingZone();
145 }
146 }
147
148 @Override
149 public void beforeTestExecution(@NonNull ExtensionContext context) {
150 Class<?> testClass = context.getTestClass().orElse(null);
151 Method testMethod = context.getTestMethod().orElse(null);
152 Object testInstance = context.getTestInstance().orElse(null);
153
154 if (testMethod == null || testInstance == null) {
155 initContext = new ParamValueInitContext(testInstance, testClass, testMethod,
156 "@Test failed to acquire instance of test class, or target method");
157 return;
158 }
159
160 TestRun.enterNoMockingZone();
161
162 try {
163 savePointForTestMethod = new SavePoint();
164 createInstancesForTestedFieldsFromBaseClasses(testInstance);
165 initContext = new ParamValueInitContext(testInstance, testClass, testMethod, null);
166 parameterValues = createInstancesForAnnotatedParameters(testInstance, testMethod, null);
167 createInstancesForTestedFields(testInstance);
168 } catch (Throwable e) {
169 if (isExpectedException(context, e)) {
170 throw new TestAbortedException("Expected exception occurred in setup: " + e.getMessage());
171 }
172 throw e;
173 } finally {
174 TestRun.exitNoMockingZone();
175 }
176
177 TestRun.setRunningIndividualTest(testInstance);
178 }
179
180 @Override
181 public boolean supportsParameter(@NonNull ParameterContext parameterContext,
182 @NonNull ExtensionContext extensionContext) {
183 return parameterContext.isAnnotated(Tested.class) || parameterContext.isAnnotated(Mocked.class)
184 || parameterContext.isAnnotated(Injectable.class) || parameterContext.isAnnotated(Capturing.class);
185 }
186
187 @Override
188 public Object resolveParameter(@NonNull ParameterContext parameterContext,
189 @NonNull ExtensionContext extensionContext) {
190 int parameterIndex = parameterContext.getIndex();
191 if (parameterValues == null) {
192 String warning = initContext.warning;
193 StringBuilder exceptionMessage = new StringBuilder(
194 "JMockit failed to provide parameters to JUnit 5 ParameterResolver.");
195 if (warning != null) {
196 exceptionMessage.append("\nAdditional info: ").append(warning);
197 }
198 exceptionMessage.append("\n - Class: ").append(initContext.displayClass());
199 exceptionMessage.append("\n - Method: ").append(initContext.displayMethod());
200 throw new IllegalStateException(exceptionMessage.toString());
201 }
202 return parameterValues[parameterIndex];
203 }
204
205 @Override
206 public void handleTestExecutionException(@NonNull ExtensionContext context, @NonNull Throwable throwable)
207 throws Throwable {
208 if (isExpectedException(context, throwable)) {
209
210 return;
211 }
212
213 thrownByTest = throwable;
214 throw throwable;
215 }
216
217 @Override
218 public void afterTestExecution(@NonNull ExtensionContext context) {
219 if (savePointForTestMethod == null) {
220 return;
221 }
222
223 TestRun.enterNoMockingZone();
224
225 try {
226 savePointForTestMethod.rollback();
227 savePointForTestMethod = null;
228
229 if (thrownByTest != null) {
230 StackTrace.filterStackTrace(thrownByTest);
231 }
232
233 Error expectationsFailure = RecordAndReplayExecution.endCurrentReplayIfAny();
234 FakeStates fakeStates = TestRun.getFakeStates();
235
236 if (expectationsFailure != null && isExpectedException(context, expectationsFailure)) {
237
238 clearTestedObjectsIfAny();
239 return;
240 }
241
242 fakeStates.verifyMissingInvocations();
243 clearTestedObjectsIfAny();
244
245 if (expectationsFailure != null) {
246 StackTrace.filterStackTrace(expectationsFailure);
247 fakeStates.resetExpectations();
248 throw expectationsFailure;
249 }
250 fakeStates.resetExpectations();
251 } finally {
252 TestRun.finishCurrentTestExecution();
253 TestRun.exitNoMockingZone();
254 }
255 }
256
257 @Override
258 public void afterEach(@NonNull ExtensionContext context) {
259 if (savePointForTest != null) {
260 savePointForTest.rollback();
261 savePointForTest = null;
262 }
263 }
264
265 @Override
266 public void afterAll(@NonNull ExtensionContext context) {
267 if (savePointForTestClass != null && isRegularTestClass(context)) {
268 savePointForTestClass.rollback();
269 savePointForTestClass = null;
270
271 clearFieldTypeRedefinitions();
272 TestRun.setCurrentTestClass(null);
273 }
274 }
275
276 private static class ParamValueInitContext {
277 private final Object instance;
278 private final Class<?> clazz;
279 private final Method method;
280 private final String warning;
281
282 ParamValueInitContext(Object instance, Class<?> clazz, Method method, String warning) {
283 this.instance = instance;
284 this.clazz = clazz;
285 this.method = method;
286 this.warning = warning;
287 }
288
289 boolean isBeforeAllMethod() {
290 return method != null && method.getDeclaredAnnotation(BeforeAll.class) != null;
291 }
292
293 boolean isBeforeEachMethod() {
294 return method != null && method.getDeclaredAnnotation(BeforeEach.class) != null;
295 }
296
297 String displayClass() {
298 return clazz == null ? "<no class reference>" : clazz.getName();
299 }
300
301 String displayMethod() {
302 if (method == null) {
303 return "<no method reference>";
304 }
305 String methodPrefix = isBeforeAllMethod() ? "@BeforeAll " : isBeforeEachMethod() ? "@BeforeEach " : "";
306 String args = Arrays.stream(method.getParameterTypes()).map(Class::getName)
307 .collect(Collectors.joining(", "));
308 return methodPrefix + method.getName() + "(" + args + ")";
309 }
310
311 @Override
312 public String toString() {
313 return "ParamContext{hasInstance=" + (instance == null ? "false" : "true") + ", class=" + clazz
314 + ", method=" + method + ", warning=" + warning + "}";
315 }
316 }
317
318 private static boolean isExpectedException(@NonNull ExtensionContext context, @NonNull Throwable throwable) {
319 Method testMethod = context.getTestMethod().orElse(null);
320 ExpectedException expectedException = testMethod != null ? testMethod.getAnnotation(ExpectedException.class)
321 : null;
322
323 if (expectedException == null) {
324 return false;
325 }
326
327 return expectedException.value().isInstance(throwable) && matchesExpectedMessages(throwable, expectedException);
328 }
329
330 private static boolean matchesExpectedMessages(Throwable throwable, ExpectedException expectedException) {
331 String[] expectedMessages = expectedException.expectedMessages();
332 if (expectedMessages.length == 0) {
333
334 return true;
335 }
336
337 String actualMessage = throwable.getMessage();
338 if (actualMessage == null) {
339 return false;
340 }
341
342 boolean contains = expectedException.messageContains();
343 for (String expected : expectedMessages) {
344 if (contains) {
345 if (actualMessage.contains(expected)) {
346 return true;
347 }
348 } else if (actualMessage.equals(expected)) {
349 return true;
350 }
351 }
352 return false;
353 }
354
355 }