1
2
3
4
5 package mockit.internal.injection;
6
7 import static mockit.internal.injection.InjectionPoint.SERVLET_CLASS;
8 import static mockit.internal.injection.InjectionPoint.isServlet;
9 import static mockit.internal.reflection.ParameterReflection.getParameterCount;
10 import static mockit.internal.util.Utilities.NO_ARGS;
11
12 import edu.umd.cs.findbugs.annotations.NonNull;
13 import edu.umd.cs.findbugs.annotations.Nullable;
14
15 import java.lang.annotation.Annotation;
16 import java.lang.reflect.Method;
17 import java.util.ArrayList;
18 import java.util.IdentityHashMap;
19 import java.util.List;
20 import java.util.Map;
21 import java.util.Map.Entry;
22
23 import javax.annotation.PostConstruct;
24 import javax.annotation.PreDestroy;
25 import javax.servlet.ServletConfig;
26
27 import mockit.internal.reflection.MethodReflection;
28 import mockit.internal.state.TestRun;
29
30 public final class LifecycleMethods {
31 @NonNull
32 private final List<Class<?>> classesSearched;
33 @NonNull
34 private final Map<Class<?>, Method> initializationMethods;
35 @NonNull
36 private final Map<Class<?>, Method> terminationMethods;
37 @NonNull
38 private final Map<Class<?>, Object> objectsWithTerminationMethodsToExecute;
39 @Nullable
40 private Object servletConfig;
41
42 LifecycleMethods() {
43 classesSearched = new ArrayList<>();
44 initializationMethods = new IdentityHashMap<>();
45 terminationMethods = new IdentityHashMap<>();
46 objectsWithTerminationMethodsToExecute = new IdentityHashMap<>();
47 }
48
49 public void findLifecycleMethods(@NonNull Class<?> testedClass) {
50 if (testedClass.isInterface() || classesSearched.contains(testedClass)) {
51 return;
52 }
53
54 boolean isServlet = isServlet(testedClass);
55 Class<?> classWithLifecycleMethods = testedClass;
56
57 do {
58 findLifecycleMethodsInSingleClass(isServlet, classWithLifecycleMethods);
59 classWithLifecycleMethods = classWithLifecycleMethods.getSuperclass();
60 } while (classWithLifecycleMethods != Object.class);
61
62 classesSearched.add(testedClass);
63 }
64
65 private void findLifecycleMethodsInSingleClass(boolean isServlet, @NonNull Class<?> classWithLifecycleMethods) {
66 Method initializationMethod = null;
67 Method terminationMethod = null;
68 int methodsFoundInSameClass = 0;
69
70 for (Method method : classWithLifecycleMethods.getDeclaredMethods()) {
71 if (method.isSynthetic()) {
72 continue;
73 }
74
75 if (initializationMethod == null && isInitializationMethod(method, isServlet)) {
76 initializationMethods.put(classWithLifecycleMethods, method);
77 initializationMethod = method;
78 methodsFoundInSameClass++;
79 } else if (terminationMethod == null && isTerminationMethod(method, isServlet)) {
80 terminationMethods.put(classWithLifecycleMethods, method);
81 terminationMethod = method;
82 methodsFoundInSameClass++;
83 }
84
85 if (methodsFoundInSameClass == 2) {
86 break;
87 }
88 }
89 }
90
91 private static boolean isInitializationMethod(@NonNull Method method, boolean isServlet) {
92 if (hasLifecycleAnnotation(method, true)) {
93 return true;
94 }
95
96 if (isServlet && "init".equals(method.getName())) {
97 Class<?>[] parameterTypes = method.getParameterTypes();
98 return parameterTypes.length == 1 && parameterTypes[0] == ServletConfig.class;
99 }
100
101 return false;
102 }
103
104 private static boolean hasLifecycleAnnotation(@NonNull Method method, boolean postConstruct) {
105 try {
106 Class<? extends Annotation> lifecycleAnnotation = postConstruct ? PostConstruct.class : PreDestroy.class;
107
108 if (method.isAnnotationPresent(lifecycleAnnotation)) {
109 return true;
110 }
111 } catch (NoClassDefFoundError ignore) {
112 }
113
114 return false;
115 }
116
117 private static boolean isTerminationMethod(@NonNull Method method, boolean isServlet) {
118 return hasLifecycleAnnotation(method, false)
119 || isServlet && "destroy".equals(method.getName()) && getParameterCount(method) == 0;
120 }
121
122 public void executeInitializationMethodsIfAny(@NonNull Class<?> testedClass, @NonNull Object testedObject) {
123 Class<?> superclass = testedClass.getSuperclass();
124
125 if (superclass != Object.class) {
126 executeInitializationMethodsIfAny(superclass, testedObject);
127 }
128
129 Method postConstructMethod = initializationMethods.get(testedClass);
130
131 if (postConstructMethod != null) {
132 executeInitializationMethod(testedObject, postConstructMethod);
133 }
134
135 Method preDestroyMethod = terminationMethods.get(testedClass);
136
137 if (preDestroyMethod != null) {
138 objectsWithTerminationMethodsToExecute.put(testedClass, testedObject);
139 }
140 }
141
142 private void executeInitializationMethod(@NonNull Object testedObject, @NonNull Method initializationMethod) {
143 Object[] args = NO_ARGS;
144
145 if ("init".equals(initializationMethod.getName()) && getParameterCount(initializationMethod) == 1) {
146 args = new Object[] { servletConfig };
147 }
148
149 TestRun.exitNoMockingZone();
150
151 try {
152 MethodReflection.invoke(testedObject, initializationMethod, args);
153 } finally {
154 TestRun.enterNoMockingZone();
155 }
156 }
157
158 void executeTerminationMethodsIfAny() {
159 try {
160 for (Entry<Class<?>, Object> testedClassAndObject : objectsWithTerminationMethodsToExecute.entrySet()) {
161 executeTerminationMethod(testedClassAndObject.getKey(), testedClassAndObject.getValue());
162 }
163 } finally {
164 objectsWithTerminationMethodsToExecute.clear();
165 }
166 }
167
168 private void executeTerminationMethod(@NonNull Class<?> testedClass, @NonNull Object testedObject) {
169 Method terminationMethod = terminationMethods.get(testedClass);
170 TestRun.exitNoMockingZone();
171
172 try {
173 MethodReflection.invoke(testedObject, terminationMethod);
174 } catch (RuntimeException | AssertionError ignore) {
175 } finally {
176 TestRun.enterNoMockingZone();
177 }
178 }
179
180 void getServletConfigForInitMethodsIfAny(@NonNull List<? extends InjectionProvider> injectables,
181 @NonNull Object testClassInstance) {
182 if (SERVLET_CLASS != null) {
183 for (InjectionProvider injectable : injectables) {
184 if (injectable.getDeclaredType() == ServletConfig.class) {
185 servletConfig = injectable.getValue(testClassInstance);
186 break;
187 }
188 }
189 }
190 }
191 }