View Javadoc
1   /*
2    * MIT License
3    * Copyright (c) 2006-2025 JMockit developers
4    * See LICENSE file for full license text.
5    */
6   package mockit.internal.faking;
7   
8   import edu.umd.cs.findbugs.annotations.NonNull;
9   import edu.umd.cs.findbugs.annotations.Nullable;
10  
11  import java.lang.reflect.Member;
12  import java.lang.reflect.Method;
13  
14  import mockit.internal.expectations.invocation.MissingInvocation;
15  import mockit.internal.expectations.invocation.UnexpectedInvocation;
16  import mockit.internal.faking.FakeMethods.FakeMethod;
17  import mockit.internal.reflection.MethodReflection;
18  import mockit.internal.reflection.RealMethodOrConstructor;
19  import mockit.internal.util.ClassLoad;
20  
21  final class FakeState {
22      private static final ClassLoader THIS_CL = FakeState.class.getClassLoader();
23  
24      @NonNull
25      final FakeMethod fakeMethod;
26      @Nullable
27      private Method actualFakeMethod;
28      @Nullable
29      private Member realMethodOrConstructor;
30      @Nullable
31      private Object realClass;
32  
33      // Constraints pulled from the @Mock annotation; negative values indicate "no constraint".
34      private int expectedInvocations;
35      private int minExpectedInvocations;
36      private int maxExpectedInvocations;
37  
38      // Current fake invocation state:
39      private int invocationCount;
40      @Nullable
41      private ThreadLocal<FakeInvocation> proceedingInvocation;
42  
43      // Helper field just for synchronization:
44      @NonNull
45      private final Object invocationCountLock;
46  
47      FakeState(@NonNull FakeMethod fakeMethod) {
48          this.fakeMethod = fakeMethod;
49          invocationCountLock = new Object();
50          expectedInvocations = -1;
51          minExpectedInvocations = 0;
52          maxExpectedInvocations = -1;
53  
54          if (fakeMethod.canBeReentered()) {
55              makeReentrant();
56          }
57      }
58  
59      FakeState(@NonNull FakeState fakeState) {
60          fakeMethod = fakeState.fakeMethod;
61          actualFakeMethod = fakeState.actualFakeMethod;
62          realMethodOrConstructor = fakeState.realMethodOrConstructor;
63          invocationCountLock = new Object();
64          realClass = fakeState.realClass;
65          invocationCount = fakeState.invocationCount;
66          expectedInvocations = fakeState.expectedInvocations;
67          minExpectedInvocations = fakeState.minExpectedInvocations;
68          maxExpectedInvocations = fakeState.maxExpectedInvocations;
69  
70          if (fakeState.proceedingInvocation != null) {
71              makeReentrant();
72          }
73      }
74  
75      @NonNull
76      Class<?> getRealClass() {
77          return fakeMethod.getRealClass();
78      }
79  
80      private void makeReentrant() {
81          proceedingInvocation = new ThreadLocal<>();
82      }
83  
84      boolean isWithExpectations() {
85          return expectedInvocations >= 0 || minExpectedInvocations > 0 || maxExpectedInvocations >= 0;
86      }
87  
88      void setExpectedInvocations(int expectedInvocations) {
89          this.expectedInvocations = expectedInvocations;
90      }
91  
92      void setMinExpectedInvocations(int minExpectedInvocations) {
93          this.minExpectedInvocations = minExpectedInvocations;
94      }
95  
96      void setMaxExpectedInvocations(int maxExpectedInvocations) {
97          this.maxExpectedInvocations = maxExpectedInvocations;
98      }
99  
100     boolean update() {
101         if (proceedingInvocation != null) {
102             FakeInvocation invocation = proceedingInvocation.get();
103 
104             if (invocation != null && invocation.proceeding) {
105                 invocation.proceeding = false;
106                 return false;
107             }
108         }
109 
110         int timesInvoked;
111 
112         synchronized (invocationCountLock) {
113             timesInvoked = ++invocationCount;
114         }
115 
116         verifyUnexpectedInvocation(timesInvoked);
117 
118         return true;
119     }
120 
121     private void verifyUnexpectedInvocation(int timesInvoked) {
122         if (expectedInvocations >= 0 && timesInvoked > expectedInvocations) {
123             throw new UnexpectedInvocation(fakeMethod.errorMessage("exactly", expectedInvocations, timesInvoked));
124         }
125 
126         if (maxExpectedInvocations >= 0 && timesInvoked > maxExpectedInvocations) {
127             throw new UnexpectedInvocation(fakeMethod.errorMessage("at most", maxExpectedInvocations, timesInvoked));
128         }
129     }
130 
131     void verifyMissingInvocations() {
132         int timesInvoked = getTimesInvoked();
133 
134         if (expectedInvocations >= 0 && timesInvoked < expectedInvocations) {
135             throw new MissingInvocation(fakeMethod.errorMessage("exactly", expectedInvocations, timesInvoked));
136         }
137 
138         if (minExpectedInvocations > 0 && timesInvoked < minExpectedInvocations) {
139             throw new MissingInvocation(fakeMethod.errorMessage("at least", minExpectedInvocations, timesInvoked));
140         }
141     }
142 
143     int getTimesInvoked() {
144         synchronized (invocationCountLock) {
145             return invocationCount;
146         }
147     }
148 
149     void reset() {
150         synchronized (invocationCountLock) {
151             invocationCount = 0;
152         }
153     }
154 
155     @NonNull
156     Member getRealMethodOrConstructor(@NonNull String fakedClassDesc, @NonNull String fakedMethodName,
157             @NonNull String fakedMethodDesc) {
158         Class<?> fakedClass = ClassLoad.loadFromLoader(THIS_CL, fakedClassDesc.replace('/', '.'));
159         return getRealMethodOrConstructor(fakedClass, fakedMethodName, fakedMethodDesc);
160     }
161 
162     @NonNull
163     Member getRealMethodOrConstructor(@NonNull Class<?> fakedClass, @NonNull String fakedMethodName,
164             @NonNull String fakedMethodDesc) {
165         Member member = realMethodOrConstructor;
166 
167         if (member == null || !fakedClass.equals(realClass)) {
168             String memberName = "$init".equals(fakedMethodName) ? "<init>" : fakedMethodName;
169 
170             RealMethodOrConstructor realMember;
171             try {
172                 realMember = new RealMethodOrConstructor(fakedClass, memberName, fakedMethodDesc);
173             } catch (NoSuchMethodException e) {
174                 throw new RuntimeException(e);
175             }
176 
177             member = realMember.getMember();
178 
179             if (!fakeMethod.isAdvice) {
180                 realMethodOrConstructor = member;
181                 realClass = fakedClass;
182             }
183         }
184 
185         return member;
186     }
187 
188     boolean shouldProceedIntoRealImplementation(@Nullable Object fake, @NonNull String classDesc) {
189         if (proceedingInvocation != null) {
190             FakeInvocation pendingInvocation = proceedingInvocation.get();
191 
192             // noinspection RedundantIfStatement
193             if (pendingInvocation != null && pendingInvocation.isMethodInSuperclass(fake, classDesc)) {
194                 return true;
195             }
196         }
197 
198         return false;
199     }
200 
201     void prepareToProceed(@NonNull FakeInvocation invocation) {
202         if (proceedingInvocation == null) {
203             throw new UnsupportedOperationException("Cannot proceed into abstract/interface method");
204         }
205 
206         if (fakeMethod.isForNativeMethod()) {
207             throw new UnsupportedOperationException("Cannot proceed into real implementation of native method");
208         }
209 
210         FakeInvocation previousInvocation = proceedingInvocation.get();
211 
212         if (previousInvocation != null) {
213             invocation.setPrevious(previousInvocation);
214         }
215 
216         proceedingInvocation.set(invocation);
217     }
218 
219     void prepareToProceedFromNonRecursiveFake(@NonNull FakeInvocation invocation) {
220         assert proceedingInvocation != null;
221         proceedingInvocation.set(invocation);
222     }
223 
224     void clearProceedIndicator() {
225         assert proceedingInvocation != null;
226         FakeInvocation currentInvocation = proceedingInvocation.get();
227         FakeInvocation previousInvocation = (FakeInvocation) currentInvocation.getPrevious();
228         proceedingInvocation.set(previousInvocation);
229     }
230 
231     @NonNull
232     Method getFakeMethod(@NonNull Class<?> fakeClass, @NonNull Class<?>[] parameterTypes) {
233         if (actualFakeMethod == null) {
234             actualFakeMethod = MethodReflection.findCompatibleMethod(fakeClass, fakeMethod.name, parameterTypes);
235         }
236 
237         return actualFakeMethod;
238     }
239 }