001    package edu.rice.cs.cunit.instrumentors;
002    
003    import edu.rice.cs.cunit.classFile.ClassFile;
004    import edu.rice.cs.cunit.classFile.MethodInfo;
005    import edu.rice.cs.cunit.classFile.attributes.CodeAttributeInfo;
006    import edu.rice.cs.cunit.classFile.code.InstructionList;
007    import edu.rice.cs.cunit.classFile.code.Opcode;
008    import edu.rice.cs.cunit.classFile.code.instructions.GenericInstruction;
009    import edu.rice.cs.cunit.util.Types;
010    
011    import java.util.ArrayList;
012    
013    /**
014     * Instrumentor that turns a synchronized methods into an unsynchronized method with a synchronized block.
015     * <p/>
016     * NOTE: This instrumentation strategy has to be run before SynchronizedBlockStrategy.
017     *
018     * @author Mathias Ricken
019     */
020    public class SynchronizedMethodToBlockStrategy implements IInstrumentationStrategy {
021        /**
022         * Methods to add.
023         */
024        protected ArrayList<MethodInfo> _newMethods = new ArrayList<MethodInfo>();
025    
026        /**
027         * Constructor.
028         */
029        public SynchronizedMethodToBlockStrategy() {
030            // nothing to do
031        }
032    
033        /**
034         * Instrument the class.
035         *
036         * @param cf class file info
037         */
038        public void instrument(ClassFile cf) {
039            //Debug.out.println("Instrumenting synchronized methods");
040    
041            // process methods
042            instrumentSynchronizedMethods(cf);
043        }
044    
045        /**
046         * Instrument synchronized methods in this class. Inside a non-native synchronized method, a synchronized
047         * block around the entire code is added, and the method is changed to unsynchronized.
048         *
049         * @param cf class file info
050         */
051        protected void instrumentSynchronizedMethods(ClassFile cf) {
052            // process all methods in this class
053            byte[] loadCode = new byte[]{Opcode.ALOAD_0};
054            byte[] loadClassCode = new byte[]{Opcode.LDC_W, 0, 0};
055            for(MethodInfo mi : cf.getMethods()) {
056                if (0 != (mi.getAccessFlags() & ClassFile.ACC_SYNCHRONIZED)) {
057                    // synchronized...
058                    if (0 == (mi.getAccessFlags() & ClassFile.ACC_NATIVE)) {
059                        // synchronized non-native...
060                        if (0 != (mi.getAccessFlags() & ClassFile.ACC_STATIC)) {
061                            // synchronized static non-native...
062                            //Debug.out.println("Instrumenting synchronized static method...");
063    
064                            // instruction to load the class object
065                            Types.bytesFromShort(cf.getConstantPool().indexOf(cf.getThisClass()),
066                                loadClassCode,
067                                1);
068    
069                            // instrument static method
070                            instrumentSynchronizedMethod(cf,mi,loadClassCode, 2);
071                        }
072                        else {
073                            // synchronized non-static non-native...
074                            //Debug.out.println("Instrumenting synchronized method...");
075    
076                            // instrument static method
077                            instrumentSynchronizedMethod(cf,mi,loadCode, 2);
078                        }
079    
080                        // erase the synchronized flag
081                        short newFlags = (short)(mi.getAccessFlags() & ~(ClassFile.ACC_SYNCHRONIZED));
082                        mi.setAccessFlags(newFlags);
083                    }
084                    else {
085                        // synchronized native...
086                        // can't do it
087                    }
088                }
089            }
090        }
091    
092        /**
093         * Instrument a synchronized method's code blocks. This turns any method <code> synchronized T method(T t...) { abc(); } </code>
094         * into <code> T method(T t...) { synchronized (?) { abc(); } } </code>
095         * <p/>
096         *
097         * @param cf        class file
098         * @param mi        method info
099         * @param loadLockCode instructions to load the object that gets locked
100         * @param reqdStackSize required state size
101         */
102        protected void instrumentSynchronizedMethod(ClassFile cf,
103                                                    MethodInfo mi,
104                                                    byte[] loadLockCode,
105                                                    int reqdStackSize) {
106            CodeAttributeInfo codeAttr = mi.getCodeAttributeInfo();
107            InstructionList il = new InstructionList(codeAttr.getCode());
108            GenericInstruction monitorExitInstr = new GenericInstruction(new byte[] {Opcode.MONITOREXIT});
109            boolean result;
110    
111            // insert monitorenter
112            if (0 < loadLockCode.length) {
113                InstructionList loadil = new InstructionList(loadLockCode);
114                do {
115                    il.insertBeforeInstr(loadil.getInstr(), codeAttr);
116                    // since we just inserted an instruction, we must be able to advance the PC
117                    result = il.advanceIndex();
118                    assert result;
119                } while(loadil.advanceIndex());
120            }
121            il.insertBeforeInstr(new GenericInstruction(new byte[] {Opcode.MONITORENTER}), codeAttr);
122            // since we just inserted an instruction, we must be able to advance the PC
123            result = il.advanceIndex();
124            assert result;
125    
126            int exceptionEndPC = il.getPCFromIndex(il.getIndex());
127    
128            do {
129                if (Opcode.isReturn(il.getOpcode())) {
130                    // insert monitorexit
131                    if (0 < loadLockCode.length) {
132                        InstructionList loadil = new InstructionList(loadLockCode);
133                        do {
134                            il.insertInstr(loadil.getInstr(), codeAttr);
135                            // since we just inserted an instruction, we must be able to advance the PC
136                            result = il.advanceIndex();
137                            assert result;
138                        } while(loadil.advanceIndex());
139                    }
140                    il.insertInstr(monitorExitInstr, codeAttr);
141                    // since we just inserted an instruction, we must be able to advance the PC
142                    result = il.advanceIndex();
143                    assert result;
144                }
145                exceptionEndPC = il.getPCFromIndex(il.getIndex());
146            } while(il.advanceIndex());
147    
148            // insert exception handler
149            int exceptionHandlerPC = il.getPCFromIndex(il.getIndex());
150            if (0 < loadLockCode.length) {
151                InstructionList loadil = new InstructionList(loadLockCode);
152                do {
153                    il.insertInstr(loadil.getInstr(), codeAttr);
154                    // since we just inserted an instruction, we must be able to advance the PC
155                    result = il.advanceIndex();
156                    assert result;
157                } while(loadil.advanceIndex());
158            }
159            il.insertInstr(monitorExitInstr, codeAttr);
160            // since we just inserted an instruction, we must be able to advance the PC
161            result = il.advanceIndex();
162            assert result;
163    
164            // rethrow
165            il.insertInstr(new GenericInstruction(new byte[] {Opcode.ATHROW}), codeAttr);
166            // since we just inserted an instruction, we must be able to advance the PC
167            result = il.advanceIndex();
168            assert result;
169    
170            // update exceptions list
171            CodeAttributeInfo.ExceptionTableEntry[] excTable =
172                new CodeAttributeInfo.ExceptionTableEntry[codeAttr.getExceptionTableEntries().length+1];
173            System.arraycopy(codeAttr.getExceptionTableEntries(), 0, excTable, 0, codeAttr.getExceptionTableEntries().length);
174            excTable[excTable.length-1] = new CodeAttributeInfo.ExceptionTableEntry((short)0,
175                                                                                    exceptionEndPC,
176                                                                                    exceptionHandlerPC,
177                                                                                    (short)0);
178            codeAttr.setExceptionTableEntries(excTable);
179    
180            codeAttr.setCode(il.getCode());
181    
182            // make sure state size is at least reqdStackSize
183            CodeAttributeInfo.CodeProperties cp = mi.getCodeAttributeInfo().getProperties();
184            cp.maxStack = (short)Math.max(reqdStackSize, cp.maxStack);
185            mi.getCodeAttributeInfo().setProperties(cp.maxStack, cp.maxLocals);
186        }
187    
188        /**
189         * Instrumentation of all classes is done.
190         */
191        public void done() {
192            // nothing to do
193        }
194    }