1
2
3
4
5 package mockit.internal.expectations.transformation;
6
7 import static mockit.asm.jvmConstants.Opcodes.ALOAD;
8 import static mockit.asm.jvmConstants.Opcodes.DCONST_0;
9 import static mockit.asm.jvmConstants.Opcodes.FCONST_0;
10 import static mockit.asm.jvmConstants.Opcodes.GETFIELD;
11 import static mockit.asm.jvmConstants.Opcodes.GETSTATIC;
12 import static mockit.asm.jvmConstants.Opcodes.ICONST_0;
13 import static mockit.asm.jvmConstants.Opcodes.INVOKESTATIC;
14 import static mockit.asm.jvmConstants.Opcodes.INVOKEVIRTUAL;
15 import static mockit.asm.jvmConstants.Opcodes.LCONST_0;
16 import static mockit.asm.jvmConstants.Opcodes.NEW;
17 import static mockit.asm.jvmConstants.Opcodes.NEWARRAY;
18 import static mockit.asm.jvmConstants.Opcodes.POP;
19 import static mockit.asm.jvmConstants.Opcodes.PUTFIELD;
20 import static mockit.asm.jvmConstants.Opcodes.PUTSTATIC;
21 import static mockit.asm.jvmConstants.Opcodes.RETURN;
22 import static mockit.internal.util.TypeConversionBytecode.isBoxing;
23 import static mockit.internal.util.TypeConversionBytecode.isUnboxing;
24
25 import edu.umd.cs.findbugs.annotations.NonNull;
26 import edu.umd.cs.findbugs.annotations.Nullable;
27
28 import mockit.asm.controlFlow.Label;
29 import mockit.asm.jvmConstants.JVMInstruction;
30 import mockit.asm.methods.MethodWriter;
31 import mockit.asm.methods.WrappingMethodVisitor;
32 import mockit.asm.types.JavaType;
33
34 import org.checkerframework.checker.index.qual.NonNegative;
35
36 public final class InvocationBlockModifier extends WrappingMethodVisitor {
37 private static final String CLASS_DESC = "mockit/internal/expectations/ActiveInvocations";
38
39
40 @NonNull
41 private final String blockOwner;
42
43
44 @NonNegative
45 private int stackSize;
46
47
48 @NonNull
49 final ArgumentMatching argumentMatching;
50 @NonNull
51 final ArgumentCapturing argumentCapturing;
52 private boolean justAfterWithCaptureInvocation;
53
54
55 @NonNegative
56 private int lastLoadedVarIndex;
57
58 InvocationBlockModifier(@NonNull MethodWriter mw, @NonNull String blockOwner) {
59 super(mw);
60 this.blockOwner = blockOwner;
61 argumentMatching = new ArgumentMatching(this);
62 argumentCapturing = new ArgumentCapturing(this);
63 }
64
65 void generateCallToActiveInvocationsMethod(@NonNull String name) {
66 mw.visitMethodInsn(INVOKESTATIC, CLASS_DESC, name, "()V", false);
67 }
68
69 void generateCallToActiveInvocationsMethod(@NonNull String name, @NonNull String desc) {
70 visitMethodInstruction(INVOKESTATIC, CLASS_DESC, name, desc, false);
71 }
72
73 @Override
74 public void visitFieldInsn(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
75 @NonNull String desc) {
76 boolean getField = opcode == GETFIELD;
77
78 if ((getField || opcode == PUTFIELD) && blockOwner.equals(owner)) {
79 if (name.indexOf('$') > 0) {
80
81 } else if (getField && ArgumentMatching.isAnyField(name)) {
82 argumentMatching.generateCodeToAddArgumentMatcherForAnyField(owner, name, desc);
83 argumentMatching.addMatcher(stackSize);
84 return;
85 } else if (!getField && generateCodeThatReplacesAssignmentToSpecialField(name)) {
86 visitInsn(POP);
87 return;
88 }
89 }
90
91 stackSize += stackSizeVariationForFieldAccess(opcode, desc);
92 mw.visitFieldInsn(opcode, owner, name, desc);
93 }
94
95 private boolean generateCodeThatReplacesAssignmentToSpecialField(@NonNull String fieldName) {
96 if ("result".equals(fieldName)) {
97 generateCallToActiveInvocationsMethod("addResult", "(Ljava/lang/Object;)V");
98 return true;
99 }
100
101 if ("times".equals(fieldName) || "minTimes".equals(fieldName) || "maxTimes".equals(fieldName)) {
102 generateCallToActiveInvocationsMethod(fieldName, "(I)V");
103 return true;
104 }
105
106 return false;
107 }
108
109 private static int stackSizeVariationForFieldAccess(@NonNegative int opcode, @NonNull String fieldType) {
110 char c = fieldType.charAt(0);
111 boolean twoByteType = c == 'D' || c == 'J';
112
113 switch (opcode) {
114 case GETSTATIC:
115 return twoByteType ? 2 : 1;
116 case PUTSTATIC:
117 return twoByteType ? -2 : -1;
118 case GETFIELD:
119 return twoByteType ? 1 : 0;
120 case PUTFIELD:
121 return twoByteType ? -3 : -2;
122 default:
123 throw new IllegalArgumentException("Invalid field access opcode: " + opcode);
124 }
125 }
126
127 @Override
128 public void visitMethodInsn(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
129 @NonNull String desc, boolean itf) {
130 if (opcode == INVOKESTATIC && (isBoxing(owner, name, desc) || isAccessMethod(owner, name))) {
131
132
133 visitMethodInstruction(INVOKESTATIC, owner, name, desc, itf);
134 } else if (isCallToArgumentMatcher(opcode, owner, name, desc)) {
135 visitMethodInstruction(INVOKEVIRTUAL, owner, name, desc, itf);
136
137 boolean withCaptureMethod = "withCapture".equals(name);
138
139 if (argumentCapturing.registerMatcher(withCaptureMethod, desc, lastLoadedVarIndex)) {
140 justAfterWithCaptureInvocation = withCaptureMethod;
141 argumentMatching.addMatcher(stackSize);
142 }
143 } else if (isUnboxing(opcode, owner, desc)) {
144 if (justAfterWithCaptureInvocation) {
145 generateCodeToReplaceNullWithZeroOnTopOfStack(desc);
146 justAfterWithCaptureInvocation = false;
147 } else {
148 visitMethodInstruction(opcode, owner, name, desc, itf);
149 }
150 } else {
151 handleMockedOrNonMockedInvocation(opcode, owner, name, desc, itf);
152 }
153 }
154
155 private boolean isAccessMethod(@NonNull String methodOwner, @NonNull String name) {
156 return !methodOwner.equals(blockOwner) && name.startsWith("access$");
157 }
158
159 private void visitMethodInstruction(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
160 @NonNull String desc, boolean itf) {
161 if (!"()V".equals(desc)) {
162 int argAndRetSize = JavaType.getArgumentsAndReturnSizes(desc);
163 int argSize = argAndRetSize >> 2;
164
165 if (opcode == INVOKESTATIC) {
166 argSize--;
167 }
168
169 stackSize -= argSize;
170
171 int retSize = argAndRetSize & 0x03;
172 stackSize += retSize;
173 } else if (opcode != INVOKESTATIC) {
174 stackSize--;
175 }
176
177 mw.visitMethodInsn(opcode, owner, name, desc, itf);
178 }
179
180 private boolean isCallToArgumentMatcher(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
181 @NonNull String desc) {
182 return opcode == INVOKEVIRTUAL && owner.equals(blockOwner)
183 && ArgumentMatching.isCallToArgumentMatcher(name, desc);
184 }
185
186 private void generateCodeToReplaceNullWithZeroOnTopOfStack(@NonNull String unboxingMethodDesc) {
187 visitInsn(POP);
188
189 char primitiveTypeCode = unboxingMethodDesc.charAt(2);
190 int zeroOpcode;
191
192 switch (primitiveTypeCode) {
193 case 'J':
194 zeroOpcode = LCONST_0;
195 break;
196 case 'F':
197 zeroOpcode = FCONST_0;
198 break;
199 case 'D':
200 zeroOpcode = DCONST_0;
201 break;
202 default:
203 zeroOpcode = ICONST_0;
204 }
205
206 visitInsn(zeroOpcode);
207 }
208
209 private void handleMockedOrNonMockedInvocation(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
210 @NonNull String desc, boolean itf) {
211 if (argumentMatching.getMatcherCount() == 0) {
212 visitMethodInstruction(opcode, owner, name, desc, itf);
213 } else {
214 boolean mockedInvocationUsingTheMatchers = argumentMatching.handleInvocationParameters(stackSize, desc);
215 visitMethodInstruction(opcode, owner, name, desc, itf);
216 handleArgumentCapturingIfNeeded(mockedInvocationUsingTheMatchers);
217 }
218 }
219
220 private void handleArgumentCapturingIfNeeded(boolean mockedInvocationUsingTheMatchers) {
221 if (mockedInvocationUsingTheMatchers) {
222 argumentCapturing.generateCallsToCaptureMatchedArgumentsIfPending();
223 }
224
225 justAfterWithCaptureInvocation = false;
226 }
227
228 @Override
229 public void visitLabel(@NonNull Label label) {
230 mw.visitLabel(label);
231
232 if (!label.isDebug()) {
233 stackSize = 0;
234 }
235 }
236
237 @Override
238 public void visitTypeInsn(@NonNegative int opcode, @NonNull String typeDesc) {
239 argumentCapturing.registerTypeToCaptureIfApplicable(opcode, typeDesc);
240
241 if (opcode == NEW) {
242 stackSize++;
243 }
244
245 mw.visitTypeInsn(opcode, typeDesc);
246 }
247
248 @Override
249 public void visitIntInsn(@NonNegative int opcode, int operand) {
250 if (opcode != NEWARRAY) {
251 stackSize++;
252 }
253
254 mw.visitIntInsn(opcode, operand);
255 }
256
257 @Override
258 public void visitVarInsn(@NonNegative int opcode, @NonNegative int varIndex) {
259 if (opcode == ALOAD) {
260 lastLoadedVarIndex = varIndex;
261 }
262
263 argumentCapturing.registerAssignmentToCaptureVariableIfApplicable(opcode, varIndex);
264 stackSize += JVMInstruction.SIZE[opcode];
265 mw.visitVarInsn(opcode, varIndex);
266 }
267
268 @Override
269 public void visitLdcInsn(@NonNull Object cst) {
270 stackSize++;
271
272 if (cst instanceof Long || cst instanceof Double) {
273 stackSize++;
274 }
275
276 mw.visitLdcInsn(cst);
277 }
278
279 @Override
280 public void visitJumpInsn(@NonNegative int opcode, @NonNull Label label) {
281 stackSize += JVMInstruction.SIZE[opcode];
282 mw.visitJumpInsn(opcode, label);
283 }
284
285 @Override
286 public void visitTableSwitchInsn(int min, int max, @NonNull Label dflt, @NonNull Label... labels) {
287 stackSize--;
288 mw.visitTableSwitchInsn(min, max, dflt, labels);
289 }
290
291 @Override
292 public void visitLookupSwitchInsn(@NonNull Label dflt, @NonNull int[] keys, @NonNull Label[] labels) {
293 stackSize--;
294 mw.visitLookupSwitchInsn(dflt, keys, labels);
295 }
296
297 @Override
298 public void visitMultiANewArrayInsn(@NonNull String desc, @NonNegative int dims) {
299 stackSize += 1 - dims;
300 mw.visitMultiANewArrayInsn(desc, dims);
301 }
302
303 @Override
304 public void visitInsn(@NonNegative int opcode) {
305 if (opcode == RETURN) {
306 generateCallToActiveInvocationsMethod("endInvocations");
307 } else {
308 stackSize += JVMInstruction.SIZE[opcode];
309 }
310
311 mw.visitInsn(opcode);
312 }
313
314 @Override
315 public void visitLocalVariable(@NonNull String name, @NonNull String desc, @Nullable String signature,
316 @NonNull Label start, @NonNull Label end, @NonNegative int index) {
317 if (signature != null) {
318 ArgumentCapturing.registerTypeToCaptureIntoListIfApplicable(index, signature);
319 }
320
321
322
323 if (end.position > 0) {
324 mw.visitLocalVariable(name, desc, signature, start, end, index);
325 }
326 }
327
328 @NonNull
329 MethodWriter getMethodWriter() {
330 return mw;
331 }
332 }