View Javadoc
1   /*
2    * Copyright (c) 2006 JMockit developers
3    * This file is subject to the terms of the MIT license (see LICENSE.txt).
4    */
5   package mockit.integration.junit5;
6   
7   import static mockit.internal.util.StackTrace.filterStackTrace;
8   
9   import edu.umd.cs.findbugs.annotations.NonNull;
10  import edu.umd.cs.findbugs.annotations.Nullable;
11  
12  import java.lang.reflect.Method;
13  import java.util.Arrays;
14  import java.util.stream.Collectors;
15  
16  import mockit.Capturing;
17  import mockit.Injectable;
18  import mockit.Mocked;
19  import mockit.Tested;
20  import mockit.integration.TestRunnerDecorator;
21  import mockit.internal.expectations.RecordAndReplayExecution;
22  import mockit.internal.state.SavePoint;
23  import mockit.internal.state.TestRun;
24  import mockit.internal.util.Utilities;
25  
26  import org.junit.jupiter.api.BeforeAll;
27  import org.junit.jupiter.api.BeforeEach;
28  import org.junit.jupiter.api.Nested;
29  import org.junit.jupiter.api.extension.AfterAllCallback;
30  import org.junit.jupiter.api.extension.AfterEachCallback;
31  import org.junit.jupiter.api.extension.AfterTestExecutionCallback;
32  import org.junit.jupiter.api.extension.BeforeAllCallback;
33  import org.junit.jupiter.api.extension.BeforeEachCallback;
34  import org.junit.jupiter.api.extension.BeforeTestExecutionCallback;
35  import org.junit.jupiter.api.extension.ExtensionContext;
36  import org.junit.jupiter.api.extension.ParameterContext;
37  import org.junit.jupiter.api.extension.ParameterResolver;
38  import org.junit.jupiter.api.extension.TestExecutionExceptionHandler;
39  import org.junit.jupiter.api.extension.TestInstancePostProcessor;
40  
41  @SuppressWarnings("Since15")
42  public final class JMockitExtension extends TestRunnerDecorator implements BeforeAllCallback, AfterAllCallback,
43          TestInstancePostProcessor, BeforeEachCallback, AfterEachCallback, BeforeTestExecutionCallback,
44          AfterTestExecutionCallback, ParameterResolver, TestExecutionExceptionHandler {
45      @Nullable
46      private SavePoint savePointForTestClass;
47      @Nullable
48      private SavePoint savePointForTest;
49      @Nullable
50      private SavePoint savePointForTestMethod;
51      @Nullable
52      private Throwable thrownByTest;
53      private Object[] parameterValues;
54      private ParamValueInitContext initContext = new ParamValueInitContext(null, null, null,
55              "No callbacks have been processed, preventing parameter population");
56  
57      @Override
58      public void beforeAll(@NonNull ExtensionContext context) {
59          if (isRegularTestClass(context)) {
60              @Nullable
61              Class<?> testClass = context.getTestClass().orElse(null);
62              savePointForTestClass = new SavePoint();
63              TestRun.setCurrentTestClass(testClass);
64  
65              if (testClass == null) {
66                  initContext = new ParamValueInitContext(null, null, null,
67                          "@BeforeAll setup failed to acquire 'Class' of test");
68                  return;
69              }
70  
71              // @BeforeAll can be used on instance methods depending on @TestInstance(PER_CLASS) usage
72              Object testInstance = context.getTestInstance().orElse(null);
73              Method beforeAllMethod = Utilities.getAnnotatedDeclaredMethod(testClass, BeforeAll.class);
74              if (testInstance == null) {
75                  initContext = new ParamValueInitContext(null, testClass, beforeAllMethod,
76                          "@BeforeAll setup failed to acquire instance of test class");
77                  return;
78              }
79  
80              if (beforeAllMethod != null) {
81                  initContext = new ParamValueInitContext(testInstance, testClass, beforeAllMethod, null);
82                  parameterValues = createInstancesForAnnotatedParameters(testInstance, beforeAllMethod, null);
83              }
84          }
85      }
86  
87      private static boolean isRegularTestClass(@NonNull ExtensionContext context) {
88          Class<?> testClass = context.getTestClass().orElse(null);
89          return testClass != null && !testClass.isAnnotationPresent(Nested.class);
90      }
91  
92      @Override
93      public void postProcessTestInstance(@NonNull Object testInstance, @NonNull ExtensionContext context) {
94          if (isRegularTestClass(context)) {
95              TestRun.enterNoMockingZone();
96  
97              try {
98                  handleMockFieldsForWholeTestClass(testInstance);
99              } finally {
100                 TestRun.exitNoMockingZone();
101             }
102 
103             TestRun.setRunningIndividualTest(testInstance);
104         }
105     }
106 
107     @Override
108     public void beforeEach(@NonNull ExtensionContext context) {
109         Object testInstance = context.getTestInstance().orElse(null);
110         Class<?> testClass = context.getTestClass().orElse(null);
111         if (testInstance == null) {
112             initContext = new ParamValueInitContext(null, null, null,
113                     "@BeforeEach setup failed to acquire instance of test class");
114             return;
115         }
116 
117         TestRun.prepareForNextTest();
118         TestRun.enterNoMockingZone();
119 
120         try {
121             savePointForTest = new SavePoint();
122             createInstancesForTestedFieldsBeforeSetup(testInstance);
123 
124             if (testClass == null) {
125                 initContext = new ParamValueInitContext(null, null, null,
126                         "@BeforeEach setup failed to acquire Class<?> of test");
127                 return;
128             }
129 
130             Method beforeEachMethod = Utilities.getAnnotatedDeclaredMethod(testClass, BeforeEach.class);
131             if (beforeEachMethod != null) {
132                 initContext = new ParamValueInitContext(testInstance, testClass, beforeEachMethod, null);
133                 parameterValues = createInstancesForAnnotatedParameters(testInstance, beforeEachMethod, null);
134             }
135         } finally {
136             TestRun.exitNoMockingZone();
137         }
138     }
139 
140     @Override
141     public void beforeTestExecution(@NonNull ExtensionContext context) {
142         Class<?> testClass = context.getTestClass().orElse(null);
143         Method testMethod = context.getTestMethod().orElse(null);
144         Object testInstance = context.getTestInstance().orElse(null);
145 
146         if (testMethod == null || testInstance == null) {
147             initContext = new ParamValueInitContext(testInstance, testClass, testMethod,
148                     "@Test failed to acquire instance of test class, or target method");
149             return;
150         }
151 
152         TestRun.enterNoMockingZone();
153 
154         try {
155             savePointForTestMethod = new SavePoint();
156             createInstancesForTestedFieldsFromBaseClasses(testInstance);
157             initContext = new ParamValueInitContext(testInstance, testClass, testMethod, null);
158             parameterValues = createInstancesForAnnotatedParameters(testInstance, testMethod, null);
159             createInstancesForTestedFields(testInstance);
160         } finally {
161             TestRun.exitNoMockingZone();
162         }
163 
164         TestRun.setRunningIndividualTest(testInstance);
165     }
166 
167     @Override
168     public boolean supportsParameter(@NonNull ParameterContext parameterContext,
169             @NonNull ExtensionContext extensionContext) {
170         return parameterContext.isAnnotated(Tested.class) || parameterContext.isAnnotated(Mocked.class)
171                 || parameterContext.isAnnotated(Injectable.class) || parameterContext.isAnnotated(Capturing.class);
172     }
173 
174     @Override
175     public Object resolveParameter(@NonNull ParameterContext parameterContext,
176             @NonNull ExtensionContext extensionContext) {
177         int parameterIndex = parameterContext.getIndex();
178         if (parameterValues == null) {
179             String warning = initContext.warning;
180             StringBuilder exceptionMessage = new StringBuilder(
181                     "JMockit failed to provide parameters to JUnit 5 ParameterResolver.");
182             if (warning != null) {
183                 exceptionMessage.append("\nAdditional info: ").append(warning);
184             }
185             exceptionMessage.append("\n - Class: ").append(initContext.displayClass());
186             exceptionMessage.append("\n - Method: ").append(initContext.displayMethod());
187             throw new IllegalStateException(exceptionMessage.toString());
188         }
189         return parameterValues[parameterIndex];
190     }
191 
192     @Override
193     public void handleTestExecutionException(@NonNull ExtensionContext context, @NonNull Throwable throwable)
194             throws Throwable {
195         thrownByTest = throwable;
196         throw throwable;
197     }
198 
199     @Override
200     public void afterTestExecution(@NonNull ExtensionContext context) {
201         if (savePointForTestMethod != null) {
202             TestRun.enterNoMockingZone();
203 
204             try {
205                 savePointForTestMethod.rollback();
206                 savePointForTestMethod = null;
207 
208                 if (thrownByTest != null) {
209                     filterStackTrace(thrownByTest);
210                 }
211 
212                 Error expectationsFailure = RecordAndReplayExecution.endCurrentReplayIfAny();
213                 clearTestedObjectsIfAny();
214 
215                 if (expectationsFailure != null) {
216                     filterStackTrace(expectationsFailure);
217                     throw expectationsFailure;
218                 }
219             } finally {
220                 TestRun.finishCurrentTestExecution();
221                 TestRun.exitNoMockingZone();
222             }
223         }
224     }
225 
226     @Override
227     public void afterEach(@NonNull ExtensionContext context) {
228         if (savePointForTest != null) {
229             savePointForTest.rollback();
230             savePointForTest = null;
231         }
232     }
233 
234     @Override
235     public void afterAll(@NonNull ExtensionContext context) {
236         if (savePointForTestClass != null && isRegularTestClass(context)) {
237             savePointForTestClass.rollback();
238             savePointForTestClass = null;
239 
240             clearFieldTypeRedefinitions();
241             TestRun.setCurrentTestClass(null);
242         }
243     }
244 
245     private static class ParamValueInitContext {
246         private final Object instance;
247         private final Class<?> clazz;
248         private final Method method;
249         private final String warning;
250 
251         ParamValueInitContext(Object instance, Class<?> clazz, Method method, String warning) {
252             this.instance = instance;
253             this.clazz = clazz;
254             this.method = method;
255             this.warning = warning;
256         }
257 
258         boolean isBeforeAllMethod() {
259             return method.getDeclaredAnnotation(BeforeAll.class) != null;
260         }
261 
262         boolean isBeforeEachMethod() {
263             return method.getDeclaredAnnotation(BeforeEach.class) != null;
264         }
265 
266         String displayClass() {
267             if (clazz == null) {
268                 return "<no class reference>";
269             }
270             return clazz.getName();
271         }
272 
273         String displayMethod() {
274             if (method == null) {
275                 return "<no method reference>";
276             }
277             String methodPrefix = isBeforeAllMethod() ? "@BeforeAll " : isBeforeEachMethod() ? "@BeforeEach " : "";
278             String args = Arrays.stream(method.getParameterTypes()).map(Class::getName)
279                     .collect(Collectors.joining(", "));
280             return methodPrefix + method.getName() + "(" + args + ")";
281         }
282 
283         @Override
284         public String toString() {
285             return "ParamContext{" + "hasInstance=" + (instance == null ? "false" : "true") + ", class=" + clazz
286                     + ", method=" + method + '}';
287         }
288     }
289 }