View Javadoc
1   /*
2    * Copyright (c) 2006 JMockit developers
3    * This file is subject to the terms of the MIT license (see LICENSE.txt).
4    */
5   package mockit.internal.injection;
6   
7   import static mockit.internal.injection.InjectionPoint.SERVLET_CLASS;
8   import static mockit.internal.injection.InjectionPoint.isServlet;
9   import static mockit.internal.reflection.ParameterReflection.getParameterCount;
10  import static mockit.internal.util.Utilities.NO_ARGS;
11  
12  import edu.umd.cs.findbugs.annotations.NonNull;
13  import edu.umd.cs.findbugs.annotations.Nullable;
14  
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
22  
23  import javax.annotation.PostConstruct;
24  import javax.annotation.PreDestroy;
25  import javax.servlet.ServletConfig;
26  
27  import mockit.internal.reflection.MethodReflection;
28  import mockit.internal.state.TestRun;
29  
30  public final class LifecycleMethods {
31      @NonNull
32      private final List<Class<?>> classesSearched;
33      @NonNull
34      private final Map<Class<?>, Method> initializationMethods;
35      @NonNull
36      private final Map<Class<?>, Method> terminationMethods;
37      @NonNull
38      private final Map<Class<?>, Object> objectsWithTerminationMethodsToExecute;
39      @Nullable
40      private Object servletConfig;
41  
42      LifecycleMethods() {
43          classesSearched = new ArrayList<>();
44          initializationMethods = new IdentityHashMap<>();
45          terminationMethods = new IdentityHashMap<>();
46          objectsWithTerminationMethodsToExecute = new IdentityHashMap<>();
47      }
48  
49      public void findLifecycleMethods(@NonNull Class<?> testedClass) {
50          if (testedClass.isInterface() || classesSearched.contains(testedClass)) {
51              return;
52          }
53  
54          boolean isServlet = isServlet(testedClass);
55          Class<?> classWithLifecycleMethods = testedClass;
56  
57          do {
58              findLifecycleMethodsInSingleClass(isServlet, classWithLifecycleMethods);
59              classWithLifecycleMethods = classWithLifecycleMethods.getSuperclass();
60          } while (classWithLifecycleMethods != Object.class);
61  
62          classesSearched.add(testedClass);
63      }
64  
65      private void findLifecycleMethodsInSingleClass(boolean isServlet, @NonNull Class<?> classWithLifecycleMethods) {
66          Method initializationMethod = null;
67          Method terminationMethod = null;
68          int methodsFoundInSameClass = 0;
69  
70          for (Method method : classWithLifecycleMethods.getDeclaredMethods()) {
71              if (method.isSynthetic()) {
72                  continue;
73              }
74  
75              if (initializationMethod == null && isInitializationMethod(method, isServlet)) {
76                  initializationMethods.put(classWithLifecycleMethods, method);
77                  initializationMethod = method;
78                  methodsFoundInSameClass++;
79              } else if (terminationMethod == null && isTerminationMethod(method, isServlet)) {
80                  terminationMethods.put(classWithLifecycleMethods, method);
81                  terminationMethod = method;
82                  methodsFoundInSameClass++;
83              }
84  
85              if (methodsFoundInSameClass == 2) {
86                  break;
87              }
88          }
89      }
90  
91      private static boolean isInitializationMethod(@NonNull Method method, boolean isServlet) {
92          if (hasLifecycleAnnotation(method, true)) {
93              return true;
94          }
95  
96          if (isServlet && "init".equals(method.getName())) {
97              Class<?>[] parameterTypes = method.getParameterTypes();
98              return parameterTypes.length == 1 && parameterTypes[0] == ServletConfig.class;
99          }
100 
101         return false;
102     }
103 
104     private static boolean hasLifecycleAnnotation(@NonNull Method method, boolean postConstruct) {
105         try {
106             Class<? extends Annotation> lifecycleAnnotation = postConstruct ? PostConstruct.class : PreDestroy.class;
107 
108             if (method.isAnnotationPresent(lifecycleAnnotation)) {
109                 return true;
110             }
111         } catch (NoClassDefFoundError ignore) {
112             /* can occur on JDK 9 */ }
113 
114         return false;
115     }
116 
117     private static boolean isTerminationMethod(@NonNull Method method, boolean isServlet) {
118         return hasLifecycleAnnotation(method, false)
119                 || isServlet && "destroy".equals(method.getName()) && getParameterCount(method) == 0;
120     }
121 
122     public void executeInitializationMethodsIfAny(@NonNull Class<?> testedClass, @NonNull Object testedObject) {
123         Class<?> superclass = testedClass.getSuperclass();
124 
125         if (superclass != Object.class) {
126             executeInitializationMethodsIfAny(superclass, testedObject);
127         }
128 
129         Method postConstructMethod = initializationMethods.get(testedClass);
130 
131         if (postConstructMethod != null) {
132             executeInitializationMethod(testedObject, postConstructMethod);
133         }
134 
135         Method preDestroyMethod = terminationMethods.get(testedClass);
136 
137         if (preDestroyMethod != null) {
138             objectsWithTerminationMethodsToExecute.put(testedClass, testedObject);
139         }
140     }
141 
142     private void executeInitializationMethod(@NonNull Object testedObject, @NonNull Method initializationMethod) {
143         Object[] args = NO_ARGS;
144 
145         if ("init".equals(initializationMethod.getName()) && getParameterCount(initializationMethod) == 1) {
146             args = new Object[] { servletConfig };
147         }
148 
149         TestRun.exitNoMockingZone();
150 
151         try {
152             MethodReflection.invoke(testedObject, initializationMethod, args);
153         } finally {
154             TestRun.enterNoMockingZone();
155         }
156     }
157 
158     void executeTerminationMethodsIfAny() {
159         try {
160             for (Entry<Class<?>, Object> testedClassAndObject : objectsWithTerminationMethodsToExecute.entrySet()) {
161                 executeTerminationMethod(testedClassAndObject.getKey(), testedClassAndObject.getValue());
162             }
163         } finally {
164             objectsWithTerminationMethodsToExecute.clear();
165         }
166     }
167 
168     private void executeTerminationMethod(@NonNull Class<?> testedClass, @NonNull Object testedObject) {
169         Method terminationMethod = terminationMethods.get(testedClass);
170         TestRun.exitNoMockingZone();
171 
172         try {
173             MethodReflection.invoke(testedObject, terminationMethod);
174         } catch (RuntimeException | AssertionError ignore) {
175         } finally {
176             TestRun.enterNoMockingZone();
177         }
178     }
179 
180     void getServletConfigForInitMethodsIfAny(@NonNull List<? extends InjectionProvider> injectables,
181             @NonNull Object testClassInstance) {
182         if (SERVLET_CLASS != null) {
183             for (InjectionProvider injectable : injectables) {
184                 if (injectable.getDeclaredType() == ServletConfig.class) {
185                     servletConfig = injectable.getValue(testClassInstance);
186                     break;
187                 }
188             }
189         }
190     }
191 }