1
2
3
4
5
6 package mockit.internal.expectations.transformation;
7
8 import static mockit.asm.jvmConstants.Opcodes.ALOAD;
9 import static mockit.asm.jvmConstants.Opcodes.DCONST_0;
10 import static mockit.asm.jvmConstants.Opcodes.FCONST_0;
11 import static mockit.asm.jvmConstants.Opcodes.GETFIELD;
12 import static mockit.asm.jvmConstants.Opcodes.GETSTATIC;
13 import static mockit.asm.jvmConstants.Opcodes.ICONST_0;
14 import static mockit.asm.jvmConstants.Opcodes.INVOKESTATIC;
15 import static mockit.asm.jvmConstants.Opcodes.INVOKEVIRTUAL;
16 import static mockit.asm.jvmConstants.Opcodes.LCONST_0;
17 import static mockit.asm.jvmConstants.Opcodes.NEW;
18 import static mockit.asm.jvmConstants.Opcodes.NEWARRAY;
19 import static mockit.asm.jvmConstants.Opcodes.POP;
20 import static mockit.asm.jvmConstants.Opcodes.PUTFIELD;
21 import static mockit.asm.jvmConstants.Opcodes.PUTSTATIC;
22 import static mockit.asm.jvmConstants.Opcodes.RETURN;
23 import static mockit.internal.util.TypeConversionBytecode.isBoxing;
24 import static mockit.internal.util.TypeConversionBytecode.isUnboxing;
25
26 import edu.umd.cs.findbugs.annotations.NonNull;
27 import edu.umd.cs.findbugs.annotations.Nullable;
28
29 import mockit.asm.controlFlow.Label;
30 import mockit.asm.jvmConstants.JVMInstruction;
31 import mockit.asm.methods.MethodWriter;
32 import mockit.asm.methods.WrappingMethodVisitor;
33 import mockit.asm.types.JavaType;
34
35 import org.checkerframework.checker.index.qual.NonNegative;
36
37 public final class InvocationBlockModifier extends WrappingMethodVisitor {
38 private static final String CLASS_DESC = "mockit/internal/expectations/ActiveInvocations";
39
40
41 @NonNull
42 private final String blockOwner;
43
44
45 @NonNegative
46 private int stackSize;
47
48
49 @NonNull
50 final ArgumentMatching argumentMatching;
51 @NonNull
52 final ArgumentCapturing argumentCapturing;
53 private boolean justAfterWithCaptureInvocation;
54
55
56 @NonNegative
57 private int lastLoadedVarIndex;
58
59 InvocationBlockModifier(@NonNull MethodWriter mw, @NonNull String blockOwner) {
60 super(mw);
61 this.blockOwner = blockOwner;
62 argumentMatching = new ArgumentMatching(this);
63 argumentCapturing = new ArgumentCapturing(this);
64 }
65
66 void generateCallToActiveInvocationsMethod(@NonNull String name) {
67 mw.visitMethodInsn(INVOKESTATIC, CLASS_DESC, name, "()V", false);
68 }
69
70 void generateCallToActiveInvocationsMethod(@NonNull String name, @NonNull String desc) {
71 visitMethodInstruction(INVOKESTATIC, CLASS_DESC, name, desc, false);
72 }
73
74 @Override
75 public void visitFieldInsn(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
76 @NonNull String desc) {
77 boolean getField = opcode == GETFIELD;
78
79 if ((getField || opcode == PUTFIELD) && blockOwner.equals(owner)) {
80 if (name.indexOf('$') >= 1) {
81
82 } else if (getField && ArgumentMatching.isAnyField(name)) {
83 argumentMatching.generateCodeToAddArgumentMatcherForAnyField(owner, name, desc);
84 argumentMatching.addMatcher(stackSize);
85 return;
86 } else if (!getField && generateCodeThatReplacesAssignmentToSpecialField(name)) {
87 visitInsn(POP);
88 return;
89 }
90 }
91
92 stackSize += stackSizeVariationForFieldAccess(opcode, desc);
93 mw.visitFieldInsn(opcode, owner, name, desc);
94 }
95
96 private boolean generateCodeThatReplacesAssignmentToSpecialField(@NonNull String fieldName) {
97 if ("result".equals(fieldName)) {
98 generateCallToActiveInvocationsMethod("addResult", "(Ljava/lang/Object;)V");
99 return true;
100 }
101
102 if ("times".equals(fieldName) || "minTimes".equals(fieldName) || "maxTimes".equals(fieldName)) {
103 generateCallToActiveInvocationsMethod(fieldName, "(I)V");
104 return true;
105 }
106
107 return false;
108 }
109
110 private static int stackSizeVariationForFieldAccess(@NonNegative int opcode, @NonNull String fieldType) {
111 char c = fieldType.charAt(0);
112 boolean twoByteType = c == 'D' || c == 'J';
113
114 switch (opcode) {
115 case GETSTATIC:
116 return twoByteType ? 2 : 1;
117 case PUTSTATIC:
118 return twoByteType ? -2 : -1;
119 case GETFIELD:
120 return twoByteType ? 1 : 0;
121 case PUTFIELD:
122 return twoByteType ? -3 : -2;
123 default:
124 throw new IllegalArgumentException("Invalid field access opcode: " + opcode);
125 }
126 }
127
128 @Override
129 public void visitMethodInsn(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
130 @NonNull String desc, boolean itf) {
131 if (opcode == INVOKESTATIC && (isBoxing(owner, name, desc) || isAccessMethod(owner, name))) {
132
133
134 visitMethodInstruction(INVOKESTATIC, owner, name, desc, itf);
135 } else if (isCallToArgumentMatcher(opcode, owner, name, desc)) {
136 visitMethodInstruction(INVOKEVIRTUAL, owner, name, desc, itf);
137
138 boolean withCaptureMethod = "withCapture".equals(name);
139
140 if (argumentCapturing.registerMatcher(withCaptureMethod, desc, lastLoadedVarIndex)) {
141 justAfterWithCaptureInvocation = withCaptureMethod;
142 argumentMatching.addMatcher(stackSize);
143 }
144 } else if (isUnboxing(opcode, owner, desc)) {
145 if (justAfterWithCaptureInvocation) {
146 generateCodeToReplaceNullWithZeroOnTopOfStack(desc);
147 justAfterWithCaptureInvocation = false;
148 } else {
149 visitMethodInstruction(opcode, owner, name, desc, itf);
150 }
151 } else {
152 handleMockedOrNonMockedInvocation(opcode, owner, name, desc, itf);
153 }
154 }
155
156 private boolean isAccessMethod(@NonNull String methodOwner, @NonNull String name) {
157 return !methodOwner.equals(blockOwner) && name.startsWith("access$");
158 }
159
160 private void visitMethodInstruction(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
161 @NonNull String desc, boolean itf) {
162 if (!"()V".equals(desc)) {
163 int argAndRetSize = JavaType.getArgumentsAndReturnSizes(desc);
164 int argSize = argAndRetSize >> 2;
165
166 if (opcode == INVOKESTATIC) {
167 argSize--;
168 }
169
170 stackSize -= argSize;
171
172 int retSize = argAndRetSize & 0x03;
173 stackSize += retSize;
174 } else if (opcode != INVOKESTATIC) {
175 stackSize--;
176 }
177
178 mw.visitMethodInsn(opcode, owner, name, desc, itf);
179 }
180
181 private boolean isCallToArgumentMatcher(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
182 @NonNull String desc) {
183 return opcode == INVOKEVIRTUAL && owner.equals(blockOwner)
184 && ArgumentMatching.isCallToArgumentMatcher(name, desc);
185 }
186
187 private void generateCodeToReplaceNullWithZeroOnTopOfStack(@NonNull String unboxingMethodDesc) {
188 visitInsn(POP);
189
190 char primitiveTypeCode = unboxingMethodDesc.charAt(2);
191 int zeroOpcode;
192
193 switch (primitiveTypeCode) {
194 case 'J':
195 zeroOpcode = LCONST_0;
196 break;
197 case 'F':
198 zeroOpcode = FCONST_0;
199 break;
200 case 'D':
201 zeroOpcode = DCONST_0;
202 break;
203 default:
204 zeroOpcode = ICONST_0;
205 }
206
207 visitInsn(zeroOpcode);
208 }
209
210 private void handleMockedOrNonMockedInvocation(@NonNegative int opcode, @NonNull String owner, @NonNull String name,
211 @NonNull String desc, boolean itf) {
212 if (argumentMatching.getMatcherCount() == 0) {
213 visitMethodInstruction(opcode, owner, name, desc, itf);
214 } else {
215 boolean mockedInvocationUsingTheMatchers = argumentMatching.handleInvocationParameters(stackSize, desc);
216 visitMethodInstruction(opcode, owner, name, desc, itf);
217 handleArgumentCapturingIfNeeded(mockedInvocationUsingTheMatchers);
218 }
219 }
220
221 private void handleArgumentCapturingIfNeeded(boolean mockedInvocationUsingTheMatchers) {
222 if (mockedInvocationUsingTheMatchers) {
223 argumentCapturing.generateCallsToCaptureMatchedArgumentsIfPending();
224 }
225
226 justAfterWithCaptureInvocation = false;
227 }
228
229 @Override
230 public void visitLabel(@NonNull Label label) {
231 mw.visitLabel(label);
232
233 if (!label.isDebug()) {
234 stackSize = 0;
235 }
236 }
237
238 @Override
239 public void visitTypeInsn(@NonNegative int opcode, @NonNull String typeDesc) {
240 argumentCapturing.registerTypeToCaptureIfApplicable(opcode, typeDesc);
241
242 if (opcode == NEW) {
243 stackSize++;
244 }
245
246 mw.visitTypeInsn(opcode, typeDesc);
247 }
248
249 @Override
250 public void visitIntInsn(@NonNegative int opcode, int operand) {
251 if (opcode != NEWARRAY) {
252 stackSize++;
253 }
254
255 mw.visitIntInsn(opcode, operand);
256 }
257
258 @Override
259 public void visitVarInsn(@NonNegative int opcode, @NonNegative int varIndex) {
260 if (opcode == ALOAD) {
261 lastLoadedVarIndex = varIndex;
262 }
263
264 argumentCapturing.registerAssignmentToCaptureVariableIfApplicable(opcode, varIndex);
265 stackSize += JVMInstruction.SIZE[opcode];
266 mw.visitVarInsn(opcode, varIndex);
267 }
268
269 @Override
270 public void visitLdcInsn(@NonNull Object cst) {
271 stackSize++;
272
273 if (cst instanceof Long || cst instanceof Double) {
274 stackSize++;
275 }
276
277 mw.visitLdcInsn(cst);
278 }
279
280 @Override
281 public void visitJumpInsn(@NonNegative int opcode, @NonNull Label label) {
282 stackSize += JVMInstruction.SIZE[opcode];
283 mw.visitJumpInsn(opcode, label);
284 }
285
286 @Override
287 public void visitTableSwitchInsn(int min, int max, @NonNull Label dflt, @NonNull Label... labels) {
288 stackSize--;
289 mw.visitTableSwitchInsn(min, max, dflt, labels);
290 }
291
292 @Override
293 public void visitLookupSwitchInsn(@NonNull Label dflt, @NonNull int[] keys, @NonNull Label[] labels) {
294 stackSize--;
295 mw.visitLookupSwitchInsn(dflt, keys, labels);
296 }
297
298 @Override
299 public void visitMultiANewArrayInsn(@NonNull String desc, @NonNegative int dims) {
300 stackSize += 1 - dims;
301 mw.visitMultiANewArrayInsn(desc, dims);
302 }
303
304 @Override
305 public void visitInsn(@NonNegative int opcode) {
306 if (opcode == RETURN) {
307 generateCallToActiveInvocationsMethod("endInvocations");
308 } else {
309 stackSize += JVMInstruction.SIZE[opcode];
310 }
311
312 mw.visitInsn(opcode);
313 }
314
315 @Override
316 public void visitLocalVariable(@NonNull String name, @NonNull String desc, @Nullable String signature,
317 @NonNull Label start, @NonNull Label end, @NonNegative int index) {
318 if (signature != null) {
319 ArgumentCapturing.registerTypeToCaptureIntoListIfApplicable(index, signature);
320 }
321
322
323
324 if (end.position > 0) {
325 mw.visitLocalVariable(name, desc, signature, start, end, index);
326 }
327 }
328
329 @NonNull
330 MethodWriter getMethodWriter() {
331 return mw;
332 }
333 }