001    package edu.rice.cs.cunit.instrumentors;
002    
003    import edu.rice.cs.cunit.SyncPointBuffer;
004    import edu.rice.cs.cunit.classFile.ClassFile;
005    import edu.rice.cs.cunit.classFile.MethodInfo;
006    import edu.rice.cs.cunit.classFile.attributes.CodeAttributeInfo;
007    import edu.rice.cs.cunit.classFile.code.InstructionList;
008    import edu.rice.cs.cunit.classFile.code.Opcode;
009    import edu.rice.cs.cunit.classFile.code.instructions.GenericInstruction;
010    import edu.rice.cs.cunit.classFile.code.instructions.ReferenceInstruction;
011    import edu.rice.cs.cunit.classFile.constantPool.*;
012    import edu.rice.cs.cunit.classFile.constantPool.visitors.CheckClassVisitor;
013    import edu.rice.cs.cunit.classFile.constantPool.visitors.CheckMethodVisitor;
014    import edu.rice.cs.cunit.classFile.constantPool.visitors.CheckUTFVisitor;
015    import edu.rice.cs.cunit.util.Types;
016    
017    /**
018     * Instrumentation strategy that adds code to java.lang.Thread to maintain a unique thread ID and an "old thread"
019     * flag.
020     *
021     * NOTE: This instrumentor has to be run AFTER the CompactSynchronizedDebugStrategy, otherwise the MONITORENTER and
022     * MONITOREXIT instructions will generate multiple sync points.
023     *
024     * @author Mathias Ricken
025     */
026    public class AssignThreadIDStrategy implements IInstrumentationStrategy {
027        /**
028         * Instrument the class.
029         *
030         * @param cf class file info
031         */
032        public void instrument(final ClassFile cf) {
033            // check if this is the target class
034            if (cf.getThisClassName().equals("java.lang.Thread")) {
035                int threadID = cf.addField(cf.getThisClass().getName().toString(), "$$$threadID$$$", "J", true,
036                                           (short)(ClassFile.ACC_PUBLIC|ClassFile.ACC_TRANSIENT));
037                int oldThread = cf.addField(cf.getThisClass().getName().toString(), "$$$oldThread$$$", "Z", true,
038                                            (short)(ClassFile.ACC_PUBLIC|ClassFile.ACC_TRANSIENT));
039                int nextThreadID = cf.addField("edu/rice/cs/cunit/SyncPointBuffer", "_nextThreadID", "J", false, (short)0);
040    
041                for(MethodInfo mi: cf.getMethods()) {
042                    if (mi.getName().toString().equals("<init>")) {
043                        processCtor(cf, mi, threadID, nextThreadID, oldThread);
044                    }
045                }
046            }
047        }
048    
049        /**
050         * Process this constructor.
051         * @param cf class file info
052         * @param mi method info
053         * @param threadID index of the threadID field
054         * @param nextThreadID index of the nextThreadID field
055         * @param oldThread index of the oldThread field
056         */
057        protected void processCtor(final ClassFile cf, final MethodInfo mi, int threadID, int nextThreadID, int oldThread) {
058            CodeAttributeInfo codeAttr = mi.getCodeAttributeInfo();
059            InstructionList il = new InstructionList(codeAttr.getCode());
060    
061            // check if this ctor calls other ctors
062            boolean ctorCalled = false;
063            do {
064                if (il.getOpcode()==Opcode.INVOKESPECIAL) {
065                    ReferenceInstruction ri = (ReferenceInstruction)il.getInstr();
066                    short method = Types.shortFromBytes(ri.getBytecode(), 1);
067                    MethodPoolInfo mpi = cf.getConstantPoolItem(method).execute(CheckMethodVisitor.singleton(), null);
068                    if ((mpi.getClassInfo().getName().toString().equals(cf.getThisClass().getName().toString())) &&
069                        (mpi.getNameAndType().getName().toString().equals("<init>"))) {
070                        ctorCalled = true;
071                        break;
072                    }
073                }
074            } while(il.advanceIndex());
075    
076            if (ctorCalled) {
077                // a ctor call was found, do not add code to this ctor
078                return;
079            }
080    
081            // inserts the equivalent of the Java following code, with sync points for the MONITORENTER and MONITOREXIT
082            // synchronized(edu.rice.cs.cunit.SyncPointBuffer.class) {
083            //     $$$threadID$$$ = ++edu.rice.cs.cunit.SyncPointBuffer._nextThreadID;
084            // }
085    
086            // TODO: How to handle thread ID overflow?
087    
088            // beginning of the method                                                              current/max stack usage
089            //    0 (pc     0): ldc2_w SP.THREADID_TRYMONITORENTER                                   2/ 2 <11L*2>
090            //    1 (pc     3): invokestatic (java/lang/Thread.currentThread)                        3/ 3 <11L*2> <threadobj>
091            //    2 (pc     6): getfield (java/lang/Thread.$$$threadID$$$)                           4/ 4 <11L*2> <tid*2>
092            //    3 (pc     9): invokestatic (edu/rice/cs/cunit/SyncPointBuffer.compactAdd)          0/ 4
093    
094            //    4 (pc    12): ldc_w (edu.rice.cs.cunit.SyncPointBuffer) // put class on stack      1/ 4 <lockobj>
095            //    5 (pc    15): monitorenter // and lock                                             0/ 4
096    
097            //    6 (pc    16): ldc2_w SP.THREADID_MONITORENTER                                      2/ 4 <9L*2>
098            //    7 (pc    19): invokestatic (java/lang/Thread.currentThread)                        3/ 4 <9L*2> <threadobj>
099            //    8 (pc    22): getfield (java/lang/Thread.$$$threadID$$$)                           4/ 4 <9L*2> <tid*2>
100            //    9 (pc    25): invokestatic (edu/rice/cs/cunit/SyncPointBuffer.compactAdd)          0/ 4
101    
102            //   10 (pc    28): aload_0                                                              1/ 4 <this>
103            //   11 (pc    29): getstatic (edu.rice.cs.cunit.SyncPointBuffer._nextThreadID)          3/ 4 <this> <nexttid*2>
104            //   12 (pc    32): dup2                                                                 5/ 5 <this> <nexttid*2> <nextoid*2>
105            //   13 (pc    33): lconst_1                                                             7/ 7 <this> <nexttid*2> <nexttid*2> <1L*2>
106            //   14 (pc    34): ladd                                                                 5/ 7 <this> <nexttid*2> <nexttid+1*2>
107            //   15 (pc    35): putstatic (edu.rice.cs.cunit.SyncPointBuffer._nextThreadID)          3/ 7 <this> <nexttid*2>
108            //   16 (pc    38): putfield (this.$$$threadID$$$)                                       0/ 7
109    
110            //   17 (pc    41): ldc2_w SP.THREADID_MONITOREXIT                                       2/ 7 <10L*2>
111            //   18 (pc    44): invokestatic (java/lang/Thread.currentThread)                        3/ 7 <10L*2> <threadobj>
112            //   19 (pc    47): getfield (java/lang/Thread.$$$threadID$$$)                           4/ 7 <10L*2> <tid*2>
113            //   20 (pc    50): invokestatic (edu/rice/cs/cunit/SyncPointBuffer.compactAdd)          0/ 7
114    
115            //   21 (pc    53): ldc_w (edu.rice.cs.cunit.SyncPointBuffer) // put class on stack      1/ 7 <lockobj>
116            //   22 (pc    56): monitorexit // and unlock                                            0/ 7
117    
118            il.setIndex(0);
119            ConstantPool cp = cf.getConstantPool();
120    
121            int threadIDIndex = cf.addField("java/lang/Thread", "$$$threadID$$$", "J", true,
122                                            (short)(ClassFile.ACC_PUBLIC | ClassFile.ACC_TRANSIENT));
123            int nextThreadIDIndex = cf.addField("edu/rice/cs/cunit/SyncPointBuffer", "_nextThreadID", "J", false,
124                                                (short)0);
125    
126            ReferenceInstruction addCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
127            ReferenceInstruction loadMonitorEnterCodeIndexInstr = new ReferenceInstruction(Opcode.LDC2_W, (short)0);
128            ReferenceInstruction loadMonitorTryEnterCodeIndexInstr = new ReferenceInstruction(Opcode.LDC2_W, (short)0);
129            ReferenceInstruction loadMonitorExitCodeIndexInstr = new ReferenceInstruction(Opcode.LDC2_W, (short)0);
130            ReferenceInstruction getThreadIDInstr = new ReferenceInstruction(Opcode.GETFIELD, (short)0);
131            ReferenceInstruction currentThreadCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
132    
133            int addCallIndex = cf.addMethodToConstantPool("edu/rice/cs/cunit/SyncPointBuffer",
134                                                          "compactAdd",
135                                                          "(JJ)V");
136            addCallInstr.setReference(addCallIndex);
137    
138            // add the code for entering a synchronized block
139            int monitorEnterCodeIndex = cf.addLongToConstantPool(SyncPointBuffer.SP.THREADID_MONITORENTER.intValue());
140            loadMonitorEnterCodeIndexInstr.setReference(monitorEnterCodeIndex);
141    
142            // add the code for trying to enter a synchronized block
143            int monitorTryEnterCodeIndex = cf.addLongToConstantPool(SyncPointBuffer.SP.THREADID_TRYMONITORENTER.intValue());
144            loadMonitorTryEnterCodeIndexInstr.setReference(monitorTryEnterCodeIndex);
145    
146            // add the field for the thread ID
147            getThreadIDInstr.setReference(threadIDIndex);
148    
149            // add the names for the call to Thread.currentThread()
150            int currentThreadCallIndex = cf.addMethodToConstantPool("java/lang/Thread",
151                                                                      "currentThread",
152                                                                      "()Ljava/lang/Thread;");
153            currentThreadCallInstr.setReference(currentThreadCallIndex);
154    
155            // add a new name
156            AUTFPoolInfo sprClassName = new ASCIIPoolInfo("edu/rice/cs/cunit/SyncPointBuffer", cp);
157            int[] l = cf.addConstantPoolItems(new APoolInfo[]{sprClassName});
158            sprClassName = cf.getConstantPoolItem(l[0]).execute(CheckUTFVisitor.singleton(), null);
159    
160            // add a new class
161            ClassPoolInfo sprClass = new ClassPoolInfo(sprClassName, cp);
162            l = cf.addConstantPoolItems(new APoolInfo[]{sprClass});
163            sprClass = cf.getConstantPoolItem(l[0]).execute(CheckClassVisitor.singleton(), null);
164            int sprClassIndex = l[0];
165    
166            // add the code for leaving a synchronized block
167            int monitorExitCodeIndex = cf.addLongToConstantPool(SyncPointBuffer.SP.THREADID_MONITOREXIT.intValue());
168            loadMonitorExitCodeIndexInstr.setReference(monitorExitCodeIndex);
169    
170            boolean result;
171    
172            // insert call to compactDebugAdd for THREADID_TRYMONITORENTER sync point
173            il.insertBeforeInstr(loadMonitorTryEnterCodeIndexInstr, mi.getCodeAttributeInfo());
174            result = il.advanceIndex();
175            assert result == true;
176            il.insertBeforeInstr(currentThreadCallInstr, mi.getCodeAttributeInfo());
177            result = il.advanceIndex();
178            assert result == true;
179            il.insertBeforeInstr(getThreadIDInstr, mi.getCodeAttributeInfo());
180            result = il.advanceIndex();
181            assert result == true;
182            il.insertBeforeInstr(addCallInstr, mi.getCodeAttributeInfo());
183            result = il.advanceIndex();
184            assert result == true;
185    
186            // insert ldc_w (edu.rice.cs.cunit.SyncPointBuffer)
187            byte[] instr = new byte[]{Opcode.LDC_W, 0, 0};
188            Types.bytesFromShort((short)(sprClassIndex & 0xffff), instr, 1);
189            il.insertBeforeInstr(new GenericInstruction(instr), codeAttr);
190            result = il.advanceIndex();
191            assert result == true; // since we just inserted an instruction, we must be able to advance the PC
192    
193            // insert monitorenter
194            il.insertBeforeInstr(new GenericInstruction(new byte[]{Opcode.MONITORENTER}), codeAttr);
195            result = il.advanceIndex();
196            assert result == true;
197    
198            // insert call to compactDebugAdd for THREADID_TRYMONITORENTER sync point
199            il.insertBeforeInstr(loadMonitorEnterCodeIndexInstr, mi.getCodeAttributeInfo());
200            result = il.advanceIndex();
201            assert result == true;
202            il.insertBeforeInstr(currentThreadCallInstr, mi.getCodeAttributeInfo());
203            result = il.advanceIndex();
204            assert result == true;
205            il.insertBeforeInstr(getThreadIDInstr, mi.getCodeAttributeInfo());
206            result = il.advanceIndex();
207            assert result == true;
208            il.insertBeforeInstr(addCallInstr, mi.getCodeAttributeInfo());
209            result = il.advanceIndex();
210            assert result == true;
211    
212            // insert aload_0
213            // this must happen here, or we can't get the instance for putfield in the right place on the stack
214            il.insertBeforeInstr(new GenericInstruction(new byte[]{Opcode.ALOAD_0}), codeAttr);
215            result = il.advanceIndex();
216            assert result == true;
217    
218            // insert getstatic (edu.rice.cs.cunit.SyncPointBuffer._nextThreadID)
219            instr[0] = Opcode.GETSTATIC;
220            Types.bytesFromShort((short)(nextThreadIDIndex & 0xffff), instr, 1);
221            il.insertBeforeInstr(new GenericInstruction(instr), codeAttr);
222            result = il.advanceIndex();
223            assert result == true;
224    
225            // insert dup2
226            il.insertBeforeInstr(new GenericInstruction(new byte[]{Opcode.DUP2}), codeAttr);
227            result = il.advanceIndex();
228            assert result == true;
229    
230            // insert lconst_1
231            il.insertBeforeInstr(new GenericInstruction(new byte[]{Opcode.LCONST_1}), codeAttr);
232            result = il.advanceIndex();
233            assert result == true;
234    
235            // insert ladd
236            il.insertBeforeInstr(new GenericInstruction(new byte[]{Opcode.LADD}), codeAttr);
237            result = il.advanceIndex();
238            assert result == true;
239    
240            // insert putstatic (edu.rice.cs.cunit.SyncPointBuffer._nextThreadID)
241            instr[0] = Opcode.PUTSTATIC;
242            il.insertBeforeInstr(new GenericInstruction(instr), codeAttr);
243            result = il.advanceIndex();
244            assert result == true;
245    
246            // insert putfield (this.$$$threadID$$$)
247            instr[0] = Opcode.PUTFIELD;
248            Types.bytesFromShort((short)(threadIDIndex & 0xffff), instr, 1);
249            il.insertBeforeInstr(new GenericInstruction(instr), codeAttr);
250            result = il.advanceIndex();
251            assert result == true;
252    
253            // insert call to compactDebugAdd for OBJID_TRYMONITORENTER sync point
254            il.insertBeforeInstr(loadMonitorExitCodeIndexInstr, mi.getCodeAttributeInfo());
255            result = il.advanceIndex();
256            assert result == true;
257            il.insertBeforeInstr(currentThreadCallInstr, mi.getCodeAttributeInfo());
258            result = il.advanceIndex();
259            assert result == true;
260            il.insertBeforeInstr(getThreadIDInstr, mi.getCodeAttributeInfo());
261            result = il.advanceIndex();
262            assert result == true;
263            il.insertBeforeInstr(addCallInstr, mi.getCodeAttributeInfo());
264            result = il.advanceIndex();
265            assert result == true;
266    
267            // insert ldc_w (edu.rice.cs.cunit.SyncPointBuffer)
268            instr[0] = Opcode.LDC_W;
269            Types.bytesFromShort((short)(sprClassIndex & 0xffff), instr, 1);
270            il.insertBeforeInstr(new GenericInstruction(instr), codeAttr);
271            result = il.advanceIndex();
272            assert result == true;
273    
274            // insert monitorexit
275            il.insertBeforeInstr(new GenericInstruction(new byte[]{Opcode.MONITOREXIT}), codeAttr);
276            result = il.advanceIndex();
277            assert result == true;
278    
279            // write code back
280            codeAttr.setCode(il.getCode());
281    
282            // make sure stack size is at least reqdStackSize
283            CodeAttributeInfo.CodeProperties codeProps = mi.getCodeAttributeInfo().getProperties();
284            codeProps.maxStack = (short)Math.max(7, codeProps.maxStack);
285            mi.getCodeAttributeInfo().setProperties(codeProps.maxStack, codeProps.maxLocals);
286        }
287    
288        /**
289         * Instrumentation of all classes is done.
290         */
291        public void done() {
292            // nothing to do
293        }
294    }