1
2
3
4
5
6 package mockit.internal.faking;
7
import edu.umd.cs.findbugs.annotations.NonNull;
import edu.umd.cs.findbugs.annotations.Nullable;

import java.lang.reflect.Member;
import java.lang.reflect.Method;
import java.util.concurrent.atomic.AtomicInteger;

import mockit.internal.expectations.invocation.MissingInvocation;
import mockit.internal.expectations.invocation.UnexpectedInvocation;
import mockit.internal.faking.FakeMethods.FakeMethod;
import mockit.internal.reflection.MethodReflection;
import mockit.internal.reflection.RealMethodOrConstructor;
import mockit.internal.util.ClassLoad;
20
21 final class FakeState {
22 private static final ClassLoader THIS_CL = FakeState.class.getClassLoader();
23
24 @NonNull
25 final FakeMethod fakeMethod;
26 @Nullable
27 private Method actualFakeMethod;
28 @Nullable
29 private Member realMethodOrConstructor;
30 @Nullable
31 private Object realClass;
32
33
34 private int expectedInvocations;
35 private int minExpectedInvocations;
36 private int maxExpectedInvocations;
37
38
39 private int invocationCount;
40 @Nullable
41 private ThreadLocal<FakeInvocation> proceedingInvocation;
42
43
44 @NonNull
45 private final Object invocationCountLock;
46
47 FakeState(@NonNull FakeMethod fakeMethod) {
48 this.fakeMethod = fakeMethod;
49 invocationCountLock = new Object();
50 expectedInvocations = -1;
51 minExpectedInvocations = 0;
52 maxExpectedInvocations = -1;
53
54 if (fakeMethod.canBeReentered()) {
55 makeReentrant();
56 }
57 }
58
59 FakeState(@NonNull FakeState fakeState) {
60 fakeMethod = fakeState.fakeMethod;
61 actualFakeMethod = fakeState.actualFakeMethod;
62 realMethodOrConstructor = fakeState.realMethodOrConstructor;
63 invocationCountLock = new Object();
64 realClass = fakeState.realClass;
65 invocationCount = fakeState.invocationCount;
66 expectedInvocations = fakeState.expectedInvocations;
67 minExpectedInvocations = fakeState.minExpectedInvocations;
68 maxExpectedInvocations = fakeState.maxExpectedInvocations;
69
70 if (fakeState.proceedingInvocation != null) {
71 makeReentrant();
72 }
73 }
74
75 @NonNull
76 Class<?> getRealClass() {
77 return fakeMethod.getRealClass();
78 }
79
80 private void makeReentrant() {
81 proceedingInvocation = new ThreadLocal<>();
82 }
83
84 boolean isWithExpectations() {
85 return expectedInvocations >= 0 || minExpectedInvocations > 0 || maxExpectedInvocations >= 0;
86 }
87
88 void setExpectedInvocations(int expectedInvocations) {
89 this.expectedInvocations = expectedInvocations;
90 }
91
92 void setMinExpectedInvocations(int minExpectedInvocations) {
93 this.minExpectedInvocations = minExpectedInvocations;
94 }
95
96 void setMaxExpectedInvocations(int maxExpectedInvocations) {
97 this.maxExpectedInvocations = maxExpectedInvocations;
98 }
99
100 boolean update() {
101 if (proceedingInvocation != null) {
102 FakeInvocation invocation = proceedingInvocation.get();
103
104 if (invocation != null && invocation.proceeding) {
105 invocation.proceeding = false;
106 return false;
107 }
108 }
109
110 int timesInvoked;
111
112 synchronized (invocationCountLock) {
113 timesInvoked = ++invocationCount;
114 }
115
116 verifyUnexpectedInvocation(timesInvoked);
117
118 return true;
119 }
120
121 private void verifyUnexpectedInvocation(int timesInvoked) {
122 if (expectedInvocations >= 0 && timesInvoked > expectedInvocations) {
123 throw new UnexpectedInvocation(fakeMethod.errorMessage("exactly", expectedInvocations, timesInvoked));
124 }
125
126 if (maxExpectedInvocations >= 0 && timesInvoked > maxExpectedInvocations) {
127 throw new UnexpectedInvocation(fakeMethod.errorMessage("at most", maxExpectedInvocations, timesInvoked));
128 }
129 }
130
131 void verifyMissingInvocations() {
132 int timesInvoked = getTimesInvoked();
133
134 if (expectedInvocations >= 0 && timesInvoked < expectedInvocations) {
135 throw new MissingInvocation(fakeMethod.errorMessage("exactly", expectedInvocations, timesInvoked));
136 }
137
138 if (minExpectedInvocations > 0 && timesInvoked < minExpectedInvocations) {
139 throw new MissingInvocation(fakeMethod.errorMessage("at least", minExpectedInvocations, timesInvoked));
140 }
141 }
142
143 int getTimesInvoked() {
144 synchronized (invocationCountLock) {
145 return invocationCount;
146 }
147 }
148
149 void reset() {
150 synchronized (invocationCountLock) {
151 invocationCount = 0;
152 }
153 }
154
155 @NonNull
156 Member getRealMethodOrConstructor(@NonNull String fakedClassDesc, @NonNull String fakedMethodName,
157 @NonNull String fakedMethodDesc) {
158 Class<?> fakedClass = ClassLoad.loadFromLoader(THIS_CL, fakedClassDesc.replace('/', '.'));
159 return getRealMethodOrConstructor(fakedClass, fakedMethodName, fakedMethodDesc);
160 }
161
162 @NonNull
163 Member getRealMethodOrConstructor(@NonNull Class<?> fakedClass, @NonNull String fakedMethodName,
164 @NonNull String fakedMethodDesc) {
165 Member member = realMethodOrConstructor;
166
167 if (member == null || !fakedClass.equals(realClass)) {
168 String memberName = "$init".equals(fakedMethodName) ? "<init>" : fakedMethodName;
169
170 RealMethodOrConstructor realMember;
171 try {
172 realMember = new RealMethodOrConstructor(fakedClass, memberName, fakedMethodDesc);
173 } catch (NoSuchMethodException e) {
174 throw new RuntimeException(e);
175 }
176
177 member = realMember.getMember();
178
179 if (!fakeMethod.isAdvice) {
180 realMethodOrConstructor = member;
181 realClass = fakedClass;
182 }
183 }
184
185 return member;
186 }
187
188 boolean shouldProceedIntoRealImplementation(@Nullable Object fake, @NonNull String classDesc) {
189 if (proceedingInvocation != null) {
190 FakeInvocation pendingInvocation = proceedingInvocation.get();
191
192
193 if (pendingInvocation != null && pendingInvocation.isMethodInSuperclass(fake, classDesc)) {
194 return true;
195 }
196 }
197
198 return false;
199 }
200
201 void prepareToProceed(@NonNull FakeInvocation invocation) {
202 if (proceedingInvocation == null) {
203 throw new UnsupportedOperationException("Cannot proceed into abstract/interface method");
204 }
205
206 if (fakeMethod.isForNativeMethod()) {
207 throw new UnsupportedOperationException("Cannot proceed into real implementation of native method");
208 }
209
210 FakeInvocation previousInvocation = proceedingInvocation.get();
211
212 if (previousInvocation != null) {
213 invocation.setPrevious(previousInvocation);
214 }
215
216 proceedingInvocation.set(invocation);
217 }
218
219 void prepareToProceedFromNonRecursiveFake(@NonNull FakeInvocation invocation) {
220 assert proceedingInvocation != null;
221 proceedingInvocation.set(invocation);
222 }
223
224 void clearProceedIndicator() {
225 assert proceedingInvocation != null;
226 FakeInvocation currentInvocation = proceedingInvocation.get();
227 FakeInvocation previousInvocation = (FakeInvocation) currentInvocation.getPrevious();
228 proceedingInvocation.set(previousInvocation);
229 }
230
231 @NonNull
232 Method getFakeMethod(@NonNull Class<?> fakeClass, @NonNull Class<?>[] parameterTypes) {
233 if (actualFakeMethod == null) {
234 actualFakeMethod = MethodReflection.findCompatibleMethod(fakeClass, fakeMethod.name, parameterTypes);
235 }
236
237 return actualFakeMethod;
238 }
239 }