1
2
3
4
5 package mockit.coverage;
6
7 import static java.lang.reflect.Modifier.isPublic;
8
9 import edu.umd.cs.findbugs.annotations.NonNull;
10 import edu.umd.cs.findbugs.annotations.Nullable;
11
12 import java.io.Serializable;
13 import java.lang.annotation.Annotation;
14 import java.lang.reflect.Method;
15 import java.util.HashMap;
16 import java.util.Map;
17
18 import mockit.internal.util.StackTrace;
19
20 import org.checkerframework.checker.index.qual.NonNegative;
21
22 public final class CallPoint implements Serializable {
23 private static final long serialVersionUID = 362727169057343840L;
24 private static final Map<StackTraceElement, Boolean> steCache = new HashMap<>();
25 private static final Class<? extends Annotation> testAnnotation;
26 private static final boolean checkTestAnnotationOnClass;
27 private static final boolean checkIfTestCaseSubclass;
28
29 static {
30 Class<?> annotation = getJUnitAnnotationIfAvailable();
31 boolean checkOnClassAlso = false;
32
33 if (annotation == null) {
34 annotation = getTestNGAnnotationIfAvailable();
35 checkOnClassAlso = true;
36 }
37
38
39 testAnnotation = (Class<? extends Annotation>) annotation;
40 checkTestAnnotationOnClass = checkOnClassAlso;
41 checkIfTestCaseSubclass = checkForJUnit3Availability();
42 }
43
44 @Nullable
45 private static Class<?> getJUnitAnnotationIfAvailable() {
46 try {
47
48 return Class.forName("org.junit.jupiter.api.Test");
49 } catch (ClassNotFoundException ignore) {
50
51 try {
52 return Class.forName("org.junit.Test");
53 } catch (ClassNotFoundException ignored) {
54 return null;
55 }
56 }
57 }
58
59 @Nullable
60 private static Class<?> getTestNGAnnotationIfAvailable() {
61 try {
62 return Class.forName("org.testng.annotations.Test");
63 } catch (ClassNotFoundException ignore) {
64
65 try {
66 return Class.forName("org.testng.Test");
67 } catch (ClassNotFoundException ignored) {
68 return null;
69 }
70 }
71 }
72
73 private static boolean checkForJUnit3Availability() {
74 try {
75 Class.forName("junit.framework.TestCase");
76 return true;
77 } catch (ClassNotFoundException ignore) {
78 return false;
79 }
80 }
81
82 @NonNull
83 private final StackTraceElement ste;
84 @NonNegative
85 private int repetitionCount;
86
87 private CallPoint(@NonNull StackTraceElement ste) {
88 this.ste = ste;
89 }
90
91 @NonNull
92 public StackTraceElement getStackTraceElement() {
93 return ste;
94 }
95
96 @NonNegative
97 public int getRepetitionCount() {
98 return repetitionCount;
99 }
100
101 public void incrementRepetitionCount() {
102 repetitionCount++;
103 }
104
105 public boolean isSameTestMethod(@NonNull CallPoint other) {
106 StackTraceElement thisSTE = ste;
107 StackTraceElement otherSTE = other.ste;
108 return thisSTE == otherSTE || thisSTE.getClassName().equals(otherSTE.getClassName())
109 && thisSTE.getMethodName().equals(otherSTE.getMethodName());
110 }
111
112 public boolean isSameLineInTestCode(@NonNull CallPoint other) {
113 return isSameTestMethod(other) && ste.getLineNumber() == other.ste.getLineNumber();
114 }
115
116 @Nullable
117 static CallPoint create(@NonNull Throwable newThrowable) {
118 StackTrace st = new StackTrace(newThrowable);
119 int n = st.getDepth();
120
121 for (int i = 2; i < n; i++) {
122 StackTraceElement ste = st.getElement(i);
123
124 if (isTestMethod(ste)) {
125 return new CallPoint(ste);
126 }
127 }
128
129 return null;
130 }
131
132 private static boolean isTestMethod(@NonNull StackTraceElement ste) {
133 if (steCache.containsKey(ste)) {
134 return steCache.get(ste);
135 }
136
137 boolean isTestMethod = false;
138
139 if (ste.getFileName() != null && ste.getLineNumber() >= 0) {
140 String className = ste.getClassName();
141
142 if (!isClassInExcludedPackage(className)) {
143 Class<?> aClass = loadClass(className);
144
145 if (aClass != null) {
146 isTestMethod = isTestMethod(aClass, ste.getMethodName());
147 }
148 }
149 }
150
151 steCache.put(ste, isTestMethod);
152 return isTestMethod;
153 }
154
155 private static boolean isClassInExcludedPackage(@NonNull String className) {
156 return className.startsWith("java.") || className.startsWith("javax.") || className.startsWith("sun.")
157 || className.startsWith("org.junit.") || className.startsWith("org.testng.")
158 || className.startsWith("mockit.");
159 }
160
161 @Nullable
162 private static Class<?> loadClass(@NonNull String className) {
163 try {
164 return Class.forName(className);
165 } catch (ClassNotFoundException | LinkageError ignore) {
166 return null;
167 }
168 }
169
170 private static boolean isTestMethod(@NonNull Class<?> testClass, @NonNull String methodName) {
171 if (checkTestAnnotationOnClass && testClass.isAnnotationPresent(testAnnotation)) {
172 return true;
173 }
174
175 Method method = findMethod(testClass, methodName);
176
177 return method != null && (containsATestFrameworkAnnotation(method.getDeclaredAnnotations())
178 || checkIfTestCaseSubclass && isJUnit3xTestMethod(testClass, method));
179 }
180
181 @Nullable
182 private static Method findMethod(@NonNull Class<?> aClass, @NonNull String name) {
183 try {
184 for (Method method : aClass.getDeclaredMethods()) {
185 if (method.getReturnType() == void.class && name.equals(method.getName())) {
186 return method;
187 }
188 }
189 } catch (NoClassDefFoundError ignore) {
190 }
191
192 return null;
193 }
194
195 private static boolean containsATestFrameworkAnnotation(@NonNull Annotation[] methodAnnotations) {
196 for (Annotation annotation : methodAnnotations) {
197 String annotationName = annotation.annotationType().getName();
198
199 if (annotationName.startsWith("org.junit.") || annotationName.startsWith("org.testng.")) {
200 return true;
201 }
202 }
203
204 return false;
205 }
206
207 private static boolean isJUnit3xTestMethod(@NonNull Class<?> aClass, @NonNull Method method) {
208 if (!isPublic(method.getModifiers()) || !method.getName().startsWith("test")) {
209 return false;
210 }
211
212 Class<?> superClass = aClass.getSuperclass();
213
214 while (superClass != Object.class) {
215 if ("junit.framework.TestCase".equals(superClass.getName())) {
216 return true;
217 }
218
219 superClass = superClass.getSuperclass();
220 }
221
222 return false;
223 }
224 }