View Javadoc
1   /*
2    * MIT License
3    * Copyright (c) 2006-2025 JMockit developers
4    * See LICENSE file for full license text.
5    */
6   package mockit.integration.springframework;
7   
8   import edu.umd.cs.findbugs.annotations.NonNull;
9   import edu.umd.cs.findbugs.annotations.Nullable;
10  
11  import mockit.internal.injection.BeanExporter;
12  import mockit.internal.injection.TestedClassInstantiations;
13  import mockit.internal.state.TestRun;
14  
15  import org.springframework.beans.factory.BeanDefinitionStoreException;
16  import org.springframework.web.context.support.StaticWebApplicationContext;
17  
18  /**
19   * A {@link org.springframework.web.context.WebApplicationContext} implementation which exposes the
20   * {@linkplain mockit.Tested @Tested} objects and their injected dependencies declared in the current test class.
21   */
22  public final class TestWebApplicationContext extends StaticWebApplicationContext {
23      @Override
24      @NonNull
25      public Object getBean(@NonNull String name) {
26          BeanExporter beanExporter = getBeanExporter();
27          return BeanLookup.getBean(beanExporter, name);
28      }
29  
30      @NonNull
31      private static BeanExporter getBeanExporter() {
32          TestedClassInstantiations testedClasses = TestRun.getTestedClassInstantiations();
33  
34          if (testedClasses == null) {
35              throw new BeanDefinitionStoreException("Test class does not define any @Tested fields");
36          }
37  
38          return testedClasses.getBeanExporter();
39      }
40  
41      @Override
42      @NonNull
43      public <T> T getBean(@NonNull String name, @Nullable Class<T> requiredType) {
44          BeanExporter beanExporter = getBeanExporter();
45          return BeanLookup.getBean(beanExporter, name, requiredType);
46      }
47  
48      @Override
49      @NonNull
50      public <T> T getBean(@NonNull Class<T> requiredType) {
51          BeanExporter beanExporter = getBeanExporter();
52          return BeanLookup.getBean(beanExporter, requiredType);
53      }
54  }