View Javadoc
1   /*
2    * MIT License
3    * Copyright (c) 2006-2025 JMockit developers
4    * See LICENSE file for full license text.
5    */
6   package mockit.integration.junit5;
7   
8   import edu.umd.cs.findbugs.annotations.NonNull;
9   import edu.umd.cs.findbugs.annotations.Nullable;
10  
11  import java.lang.reflect.Method;
12  import java.util.Arrays;
13  import java.util.stream.Collectors;
14  
15  import mockit.Capturing;
16  import mockit.Injectable;
17  import mockit.Mocked;
18  import mockit.Tested;
19  import mockit.integration.TestRunnerDecorator;
20  import mockit.internal.expectations.RecordAndReplayExecution;
21  import mockit.internal.faking.FakeStates;
22  import mockit.internal.state.SavePoint;
23  import mockit.internal.state.TestRun;
24  import mockit.internal.util.StackTrace;
25  import mockit.internal.util.Utilities;
26  
27  import org.junit.jupiter.api.BeforeAll;
28  import org.junit.jupiter.api.BeforeEach;
29  import org.junit.jupiter.api.Nested;
30  import org.junit.jupiter.api.extension.AfterAllCallback;
31  import org.junit.jupiter.api.extension.AfterEachCallback;
32  import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
33  import org.junit.jupiter.api.extension.BeforeAllCallback;
34  import org.junit.jupiter.api.extension.BeforeEachCallback;
35  import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
36  import org.junit.jupiter.api.extension.ExtensionContext;
37  import org.junit.jupiter.api.extension.ParameterContext;
38  import org.junit.jupiter.api.extension.ParameterResolver;
39  import org.junit.jupiter.api.extension.TestExecutionExceptionHandler;
40  import org.junit.jupiter.api.extension.TestInstancePostProcessor;
41  import org.opentest4j.TestAbortedException;
42  
43  public final class JMockitExtension extends TestRunnerDecorator implements BeforeAllCallback, AfterAllCallback,
44          TestInstancePostProcessor, BeforeEachCallback, AfterEachCallback, BeforeTestExecutionCallback,
45          AfterTestExecutionCallback, ParameterResolver, TestExecutionExceptionHandler {
46      @Nullable
47      private SavePoint savePointForTestClass;
48      @Nullable
49      private SavePoint savePointForTest;
50      @Nullable
51      private SavePoint savePointForTestMethod;
52      @Nullable
53      private Throwable thrownByTest;
54      private Object[] parameterValues;
55      private ParamValueInitContext initContext = new ParamValueInitContext(null, null, null,
56              "No callbacks have been processed, preventing parameter population");
57  
58      @Override
59      public void beforeAll(@NonNull ExtensionContext context) {
60          if (!isRegularTestClass(context)) {
61              return;
62          }
63  
64          @Nullable
65          Class<?> testClass = context.getTestClass().orElse(null);
66          savePointForTestClass = new SavePoint();
67          // Ensure JMockit state and test class logic is handled before any test instance is created
68          if (testClass != null) {
69              updateTestClassState(null, testClass);
70          }
71  
72          if (testClass == null) {
73              initContext = new ParamValueInitContext(null, null, null,
74                      "@BeforeAll setup failed to acquire 'Class' of test");
75              return;
76          }
77  
78          // @BeforeAll can be used on instance methods depending on @TestInstance(PER_CLASS) usage
79          Object testInstance = context.getTestInstance().orElse(null);
80          Method beforeAllMethod = Utilities.getAnnotatedDeclaredMethod(testClass, BeforeAll.class);
81          if (testInstance == null) {
82              initContext = new ParamValueInitContext(null, testClass, beforeAllMethod,
83                      "@BeforeAll setup failed to acquire instance of test class");
84              return;
85          }
86  
87          if (beforeAllMethod != null) {
88              initContext = new ParamValueInitContext(testInstance, testClass, beforeAllMethod, null);
89              parameterValues = createInstancesForAnnotatedParameters(testInstance, beforeAllMethod, null);
90          }
91      }
92  
93      private static boolean isRegularTestClass(@NonNull ExtensionContext context) {
94          Class<?> testClass = context.getTestClass().orElse(null);
95          return testClass != null && !testClass.isAnnotationPresent(Nested.class);
96      }
97  
98      @Override
99      public void postProcessTestInstance(@NonNull Object testInstance, @NonNull ExtensionContext context) {
100         if (!isRegularTestClass(context)) {
101             return;
102         }
103 
104         TestRun.enterNoMockingZone();
105 
106         try {
107             handleMockFieldsForWholeTestClass(testInstance);
108         } finally {
109             TestRun.exitNoMockingZone();
110         }
111 
112         TestRun.setRunningIndividualTest(testInstance);
113     }
114 
115     @Override
116     public void beforeEach(@NonNull ExtensionContext context) {
117         Object testInstance = context.getTestInstance().orElse(null);
118         Class<?> testClass = context.getTestClass().orElse(null);
119         if (testInstance == null) {
120             initContext = new ParamValueInitContext(null, null, null,
121                     "@BeforeEach setup failed to acquire instance of test class");
122             return;
123         }
124 
125         TestRun.prepareForNextTest();
126         TestRun.enterNoMockingZone();
127 
128         try {
129             savePointForTest = new SavePoint();
130             createInstancesForTestedFieldsBeforeSetup(testInstance);
131 
132             if (testClass == null) {
133                 initContext = new ParamValueInitContext(null, null, null,
134                         "@BeforeEach setup failed to acquire Class<?> of test");
135                 return;
136             }
137 
138             Method beforeEachMethod = Utilities.getAnnotatedDeclaredMethod(testClass, BeforeEach.class);
139             if (beforeEachMethod != null) {
140                 initContext = new ParamValueInitContext(testInstance, testClass, beforeEachMethod, null);
141                 parameterValues = createInstancesForAnnotatedParameters(testInstance, beforeEachMethod, null);
142             }
143         } finally {
144             TestRun.exitNoMockingZone();
145         }
146     }
147 
148     @Override
149     public void beforeTestExecution(@NonNull ExtensionContext context) {
150         Class<?> testClass = context.getTestClass().orElse(null);
151         Method testMethod = context.getTestMethod().orElse(null);
152         Object testInstance = context.getTestInstance().orElse(null);
153 
154         if (testMethod == null || testInstance == null) {
155             initContext = new ParamValueInitContext(testInstance, testClass, testMethod,
156                     "@Test failed to acquire instance of test class, or target method");
157             return;
158         }
159 
160         TestRun.enterNoMockingZone();
161 
162         try {
163             savePointForTestMethod = new SavePoint();
164             createInstancesForTestedFieldsFromBaseClasses(testInstance);
165             initContext = new ParamValueInitContext(testInstance, testClass, testMethod, null);
166             parameterValues = createInstancesForAnnotatedParameters(testInstance, testMethod, null);
167             createInstancesForTestedFields(testInstance);
168         } catch (Throwable e) {
169             if (isExpectedException(context, e)) {
170                 throw new TestAbortedException("Expected exception occurred in setup: " + e.getMessage());
171             }
172             throw e;
173         } finally {
174             TestRun.exitNoMockingZone();
175         }
176 
177         TestRun.setRunningIndividualTest(testInstance);
178     }
179 
180     @Override
181     public boolean supportsParameter(@NonNull ParameterContext parameterContext,
182             @NonNull ExtensionContext extensionContext) {
183         return parameterContext.isAnnotated(Tested.class) || parameterContext.isAnnotated(Mocked.class)
184                 || parameterContext.isAnnotated(Injectable.class) || parameterContext.isAnnotated(Capturing.class);
185     }
186 
187     @Override
188     public Object resolveParameter(@NonNull ParameterContext parameterContext,
189             @NonNull ExtensionContext extensionContext) {
190         int parameterIndex = parameterContext.getIndex();
191         if (parameterValues == null) {
192             String warning = initContext.warning;
193             StringBuilder exceptionMessage = new StringBuilder(
194                     "JMockit failed to provide parameters to JUnit 5 ParameterResolver.");
195             if (warning != null) {
196                 exceptionMessage.append("\nAdditional info: ").append(warning);
197             }
198             exceptionMessage.append("\n - Class: ").append(initContext.displayClass());
199             exceptionMessage.append("\n - Method: ").append(initContext.displayMethod());
200             throw new IllegalStateException(exceptionMessage.toString());
201         }
202         return parameterValues[parameterIndex];
203     }
204 
205     @Override
206     public void handleTestExecutionException(@NonNull ExtensionContext context, @NonNull Throwable throwable)
207             throws Throwable {
208         if (isExpectedException(context, throwable)) {
209             // Expected exception was thrown, suppress it (test passes)
210             return;
211         }
212 
213         thrownByTest = throwable;
214         throw throwable;
215     }
216 
217     @Override
218     public void afterTestExecution(@NonNull ExtensionContext context) {
219         if (savePointForTestMethod == null) {
220             return;
221         }
222 
223         TestRun.enterNoMockingZone();
224 
225         try {
226             savePointForTestMethod.rollback();
227             savePointForTestMethod = null;
228 
229             if (thrownByTest != null) {
230                 StackTrace.filterStackTrace(thrownByTest);
231             }
232 
233             Error expectationsFailure = RecordAndReplayExecution.endCurrentReplayIfAny();
234             FakeStates fakeStates = TestRun.getFakeStates();
235 
236             if (expectationsFailure != null && isExpectedException(context, expectationsFailure)) {
237                 // Expected JMockit error was thrown, suppress it (test passes)
238                 clearTestedObjectsIfAny();
239                 return;
240             }
241 
242             fakeStates.verifyMissingInvocations();
243             clearTestedObjectsIfAny();
244 
245             if (expectationsFailure != null) {
246                 StackTrace.filterStackTrace(expectationsFailure);
247                 fakeStates.resetExpectations();
248                 throw expectationsFailure;
249             }
250             fakeStates.resetExpectations();
251         } finally {
252             TestRun.finishCurrentTestExecution();
253             TestRun.exitNoMockingZone();
254         }
255     }
256 
257     @Override
258     public void afterEach(@NonNull ExtensionContext context) {
259         if (savePointForTest != null) {
260             savePointForTest.rollback();
261             savePointForTest = null;
262         }
263     }
264 
265     @Override
266     public void afterAll(@NonNull ExtensionContext context) {
267         if (savePointForTestClass != null && isRegularTestClass(context)) {
268             savePointForTestClass.rollback();
269             savePointForTestClass = null;
270 
271             clearFieldTypeRedefinitions();
272             TestRun.setCurrentTestClass(null);
273         }
274     }
275 
276     private static class ParamValueInitContext {
277         private final Object instance;
278         private final Class<?> clazz;
279         private final Method method;
280         private final String warning;
281 
282         ParamValueInitContext(Object instance, Class<?> clazz, Method method, String warning) {
283             this.instance = instance;
284             this.clazz = clazz;
285             this.method = method;
286             this.warning = warning;
287         }
288 
289         boolean isBeforeAllMethod() {
290             return method != null && method.getDeclaredAnnotation(BeforeAll.class) != null;
291         }
292 
293         boolean isBeforeEachMethod() {
294             return method != null && method.getDeclaredAnnotation(BeforeEach.class) != null;
295         }
296 
297         String displayClass() {
298             return clazz == null ? "<no class reference>" : clazz.getName();
299         }
300 
301         String displayMethod() {
302             if (method == null) {
303                 return "<no method reference>";
304             }
305             String methodPrefix = isBeforeAllMethod() ? "@BeforeAll " : isBeforeEachMethod() ? "@BeforeEach " : "";
306             String args = Arrays.stream(method.getParameterTypes()).map(Class::getName)
307                     .collect(Collectors.joining(", "));
308             return methodPrefix + method.getName() + "(" + args + ")";
309         }
310 
311         @Override
312         public String toString() {
313             return "ParamContext{hasInstance=" + (instance == null ? "false" : "true") + ", class=" + clazz
314                     + ", method=" + method + ", warning=" + warning + "}";
315         }
316     }
317 
318     private static boolean isExpectedException(@NonNull ExtensionContext context, @NonNull Throwable throwable) {
319         Method testMethod = context.getTestMethod().orElse(null);
320         ExpectedException expectedException = testMethod != null ? testMethod.getAnnotation(ExpectedException.class)
321                 : null;
322 
323         if (expectedException == null) {
324             return false;
325         }
326 
327         return expectedException.value().isInstance(throwable) && matchesExpectedMessages(throwable, expectedException);
328     }
329 
330     private static boolean matchesExpectedMessages(Throwable throwable, ExpectedException expectedException) {
331         String[] expectedMessages = expectedException.expectedMessages();
332         if (expectedMessages.length == 0) {
333             // No message requirement
334             return true;
335         }
336 
337         String actualMessage = throwable.getMessage();
338         if (actualMessage == null) {
339             return false;
340         }
341 
342         boolean contains = expectedException.messageContains();
343         for (String expected : expectedMessages) {
344             if (contains) {
345                 if (actualMessage.contains(expected)) {
346                     return true;
347                 }
348             } else if (actualMessage.equals(expected)) {
349                 return true;
350             }
351         }
352         return false;
353     }
354 
355 }