001    package edu.rice.cs.cunit.instrumentors.threadCheck;
002    
003    import edu.rice.cs.cunit.classFile.ClassFile;
004    import edu.rice.cs.cunit.classFile.ClassFileTools;
005    import edu.rice.cs.cunit.classFile.MethodInfo;
006    import edu.rice.cs.cunit.classFile.attributes.CodeAttributeInfo;
007    import edu.rice.cs.cunit.classFile.code.InstructionList;
008    import edu.rice.cs.cunit.classFile.code.Opcode;
009    import edu.rice.cs.cunit.classFile.code.instructions.GenericInstruction;
010    import edu.rice.cs.cunit.classFile.code.instructions.ReferenceInstruction;
011    import edu.rice.cs.cunit.classFile.constantPool.*;
012    import edu.rice.cs.cunit.classFile.constantPool.visitors.ADefaultPoolInfoVisitor;
013    import edu.rice.cs.cunit.classFile.constantPool.visitors.CheckMethodVisitor;
014    import edu.rice.cs.cunit.classFile.constantPool.visitors.CheckUTFVisitor;
015    import edu.rice.cs.cunit.threadCheck.OnlyRunBy;
016    import edu.rice.cs.cunit.threadCheck.ThreadCheck;
017    import edu.rice.cs.cunit.util.Types;
018    
019    import java.io.IOException;
020    import java.util.ArrayList;
021    import java.util.List;
022    
023    /**
024     * Instrumentor to add calls to ThreadCheck.checkCurrentThreadName/Id/Group to check if the current thread is not
025     * allowed to execute a class or method.
026     * <p/>
027     * This instrumentor checks for every method if there are @NotRunBy or @OnlyRunBy annotations attached to the method,
028     * the containing class, the same method in one of the superclasses or interfaces, or a superclass or interface, and
029     * then at the beginning of the method inserts calls to ThreadCheck..
030     *
031     * @author Mathias Ricken
032     */
033    public class AddThreadCheckStrategy extends AAddThreadCheckStrategy {
034        /**
035         * Constructor for this strategy.
036         * @param shared data shared among all AThreadCheckStrategy instances
037         * @param sharedAdd data for all AAddThreadCheckStrategy instances
038         */
039        public AddThreadCheckStrategy(SharedData shared, SharedAddData sharedAdd) {
040            this(new ArrayList<String>(), shared, sharedAdd);
041        }
042    
043        /**
044         * Constructor for this strategy.
045         * @param parameters parameters for the instrumentors
046         * @param shared data shared among all AThreadCheckStrategy instances
047         * @param sharedAdd data for all AAddThreadCheckStrategy instances
048         */
049        public AddThreadCheckStrategy(List<String> parameters, SharedData shared, SharedAddData sharedAdd) {
050            super(parameters, shared, sharedAdd);
051        }
052    
053        /**
054         * Instrument the class.
055         *
056         * @param cf class file info
057         */
058        public void instrument(final ClassFile cf) {
059            _sharedData.setCurrentClassName(cf.getThisClassName());
060            ConstantPool cp = cf.getConstantPool();
061            ReferenceInstruction checkCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
062            ReferenceInstruction loadStringInstr = new ReferenceInstruction(Opcode.LDC_W, (short)0);
063            ReferenceInstruction loadLongInstr = new ReferenceInstruction(Opcode.LDC2_W, (short)0);
064            ReferenceInstruction okCheckCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
065            ReferenceInstruction addNameCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
066            ReferenceInstruction addIdCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
067            ReferenceInstruction addGroupCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
068            ReferenceInstruction setEventThreadCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
069            ReferenceInstruction isDisplayableCallInstr = new ReferenceInstruction(Opcode.INVOKEVIRTUAL, (short)0);
070            int checkNameCallIndex = 0;
071            int checkIdCallIndex = 0;
072            int checkGroupCallIndex = 0;
073            int okCheckCallIndex = 0;
074            int addNameCallIndex = 0;
075            int addGroupCallIndex = 0;
076            int addIdCallIndex = 0;
077            int setEventThreadIndex = 0;
078            int isDisplayableCallIndex = 0;
079    
080            // process all methods in this class
081            for(MethodInfo mi : cf.getMethods()) {
082                // proces if not a native or abstract method, should have a body
083                if ((mi.getAccessFlags() & (ClassFile.ACC_NATIVE | ClassFile.ACC_ABSTRACT)) == 0) {
084                    long beginMillis = System.currentTimeMillis();
085                    ThreadCheckAnnotationRecord methodAR = getMethodAnnotations(cf, mi);
086                    long endMillis = System.currentTimeMillis();
087                    _sharedAddData.cacheInfo.addTimeSpent(endMillis-beginMillis);
088    
089                    if (!methodAR.empty()) {
090                        boolean changed = false;
091                        InstructionList il = new InstructionList(mi.getCodeAttributeInfo().getCode());
092                        
093                        // if this is a constructor ("<init>"), then we should perhaps wait until
094                        // after the this() or super() call; try to find it
095                        if (mi.getName().toString().equals("<init>")) {
096                            // check if this ctor calls other ctors
097                            boolean ctorCalled = false;
098                            do {
099                                if (il.getOpcode() == Opcode.INVOKESPECIAL) {
100                                    ReferenceInstruction ri = (ReferenceInstruction)il.getInstr();
101                                    short method = Types.shortFromBytes(ri.getBytecode(), 1);
102                                    MethodPoolInfo mpi = cf.getConstantPoolItem(method).execute(CheckMethodVisitor.singleton(), null);
103                                    if (mpi.getNameAndType().getName().toString().equals("<init>")) {
104                                        ClassFile curcf = cf;
105                                        while((curcf.getThisClassName()!=null) && (!curcf.getThisClassName().equals(""))) {
106                                            if (curcf.getThisClass().toString().equals(mpi.getClassInfo().getName().toString())) {
107                                                // super() call
108                                                ctorCalled = true;
109                                                break;
110                                            }
111                                            ClassFileTools.ClassLocation cl = null;
112                                            try {
113                                                cl = ClassFileTools.findClassFile(curcf.getSuperClassName(), _sharedData.getClassPath());
114                                                if (cl != null) {
115                                                    curcf = cl.getClassFile();
116                                                }
117                                                else {
118                                                    _sharedData.addClassNotFoundWarning(new ClassNotFoundWarning(curcf.getSuperClassName(),_sharedData.getCurrentClassName()));
119                                                    //Debug.out.println("Warning: Could not find " + curcf.getSuperClassName());
120                                                    break;
121                                                }
122                                            }
123                                            finally {
124                                                try { if (cl!=null) cl.close(); }
125                                                catch(IOException e) { /* ignore; shouldn't cause any problems except on Windows with read locks */ }
126                                            }
127                                        }
128                                        if (ctorCalled) { break; }
129                                    }
130                                }
131                            } while(il.advanceIndex());
132                            if (ctorCalled) {
133                                // super call found
134                                // advance one past the super call
135                                boolean res = il.advanceIndex();
136                                assert res == true;
137                            }
138                            else {
139                                // no this() or super() call found, start at the beginning
140                                il.setIndex(0);
141                                _sharedAddData.otherWarnings.add(new OnlyAfterRealizedWarning("ignored, no this() or super() call found in constructor "+
142                                                                                            cf.getThisClassName()+"."+mi.getName()+mi.getDescriptor()));
143                                methodAR.allowEventThread = OnlyRunBy.EVENT_THREAD.NO;
144                            }
145                        }
146                        
147                        // ========
148                        // NotRunBy
149                        // ========
150                        
151                        // add checks for thread names from NotRunBy
152                        for(String s: methodAR.denyThreadNames) {
153                            changed = true;
154                            
155                            if (checkNameCallIndex==0) {
156                                checkNameCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
157                                                                                "checkCurrentThreadName",
158                                                                                "(Ljava/lang/String;)V");
159                            }
160                            checkCallInstr.setReference(checkNameCallIndex);
161                            il.insertInstr(checkCallInstr, mi.getCodeAttributeInfo());
162                            
163                            // add the string
164                            AUTFPoolInfo utfpi = new ASCIIPoolInfo(s, cp);
165                            int[] l = cf.addConstantPoolItems(new APoolInfo[]{utfpi});
166                            utfpi = cf.getConstantPoolItem(l[0]).execute(CheckUTFVisitor.singleton(), null);
167                            
168                            StringPoolInfo spi = new StringPoolInfo(utfpi, cp);
169                            l = cf.addConstantPoolItems(new APoolInfo[]{spi});
170                            spi = cf.getConstantPoolItem(l[0]).execute(new ADefaultPoolInfoVisitor<StringPoolInfo, Object>() {
171                                public StringPoolInfo defaultCase(APoolInfo host, Object o) {
172                                    throw new ClassFormatError("Info is of type " + host.getClass().getName() + ", needs to be StringPoolInfo");
173                                }
174                                public StringPoolInfo stringCase(StringPoolInfo host, Object o) {
175                                    return host;
176                                }
177                            }, null);
178                            
179                            loadStringInstr.setReference(l[0]);
180                            il.insertInstr(loadStringInstr, mi.getCodeAttributeInfo());
181                            
182                            // now we have inserted the following bytecode:
183                            // ldc_w (string)
184                            // invokestatic (checkCurrentThreadName)
185                        }
186                        
187                        // add checks for thread ids from NotRunBy
188                        for(long id: methodAR.denyThreadIds) {
189                            changed = true;
190                            if (checkIdCallIndex==0) {
191                                checkIdCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
192                                                                              "checkCurrentThreadId",
193                                                                              "(J)V");
194                            }
195                            checkCallInstr.setReference(checkIdCallIndex);
196                            il.insertInstr(checkCallInstr, mi.getCodeAttributeInfo());
197                            
198                            // add the long
199                            loadLongInstr.setReference(cf.addLongToConstantPool(id));
200                            
201                            il.insertInstr(loadLongInstr, mi.getCodeAttributeInfo());
202                        }
203                        
204                        // add checks for thread groups from NotRunBy
205                        for(String s: methodAR.denyThreadGroups) {
206                            changed = true;
207                            if (checkGroupCallIndex==0) {
208                                checkGroupCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
209                                                                                 "checkCurrentThreadGroup",
210                                                                                 "(Ljava/lang/String;)V");
211                            }
212                            checkCallInstr.setReference(checkGroupCallIndex);
213                            il.insertInstr(checkCallInstr, mi.getCodeAttributeInfo());
214                            
215                            // add the string
216                            AUTFPoolInfo utfpi = new ASCIIPoolInfo(s, cp);
217                            int[] l = cf.addConstantPoolItems(new APoolInfo[]{utfpi});
218                            utfpi = cf.getConstantPoolItem(l[0]).execute(CheckUTFVisitor.singleton(), null);
219                            
220                            StringPoolInfo spi = new StringPoolInfo(utfpi, cp);
221                            l = cf.addConstantPoolItems(new APoolInfo[]{spi});
222                            spi = cf.getConstantPoolItem(l[0]).execute(new ADefaultPoolInfoVisitor<StringPoolInfo, Object>() {
223                                public StringPoolInfo defaultCase(APoolInfo host, Object o) {
224                                    throw new ClassFormatError("Info is of type " + host.getClass().getName() + ", needs to be StringPoolInfo");
225                                }
226                                public StringPoolInfo stringCase(StringPoolInfo host, Object o) {
227                                    return host;
228                                }
229                            }, null);
230                            
231                            loadStringInstr.setReference(l[0]);
232                            il.insertInstr(loadStringInstr, mi.getCodeAttributeInfo());
233                        }
234                        
235                        // =========
236                        // OnlyRunBy
237                        // =========
238                        
239                        int lengthBeforeOnlyRunBy = il.getLength();
240                        
241                        // add checks for event thread from OnlyRunBy
242                        if (methodAR.allowEventThread!=OnlyRunBy.EVENT_THREAD.NO) {
243                            boolean insertInstr = true;
244                            if(methodAR.allowEventThread==OnlyRunBy.EVENT_THREAD.ONLY_AFTER_REALIZED) {
245                                if ((mi.getAccessFlags() & ClassFile.ACC_STATIC) != 0) {
246                                    _sharedAddData.otherWarnings.add(new OnlyAfterRealizedWarning(
247                                        "ignored for static method "+
248                                        cf.getThisClassName()+"."+mi.getName()+mi.getDescriptor()));
249                                    insertInstr = false;
250                                }
251                                else {
252                                    // we need to check if this is a subclass of java.awt.Component
253                                    boolean found = false;
254                                    ClassFile curcf = cf;
255                                    while((curcf.getThisClassName()!=null) && (!curcf.getThisClassName().equals(""))) {
256                                        if (curcf.getThisClassName().equals("java.awt.Component")) {
257                                            found = true;
258                                            break;
259                                        }
260                                        ClassFileTools.ClassLocation cl = null;
261                                        try {
262                                            cl = ClassFileTools.findClassFile(curcf.getSuperClassName(), _sharedData.getClassPath());
263                                            if (cl != null) {
264                                                curcf = cl.getClassFile();
265                                            }
266                                            else {
267                                                _sharedData.addClassNotFoundWarning(new ClassNotFoundWarning(curcf.getSuperClassName(),_sharedData.getCurrentClassName()));
268                                                //Debug.out.println("Warning: Could not find " + curcf.getSuperClassName());
269                                                break;
270                                            }
271                                        }
272                                        finally {
273                                            try { if (cl!=null) cl.close(); }
274                                            catch(IOException e) { /* ignore; shouldn't cause any problems except on Windows with read locks */ }
275                                        }
276                                    }
277                                    if (!found) {
278                                        _sharedAddData.otherWarnings.add(new OnlyAfterRealizedWarning(
279                                            "class not subclass of java.awt.Component, ignored for method "+
280                                            cf.getThisClassName()+"."+mi.getName()+mi.getDescriptor()));
281                                        insertInstr = false;
282                                    }
283                                }
284                            }
285                            
286                            if (insertInstr) {
287                                changed = true;
288                                
289                                if (setEventThreadIndex==0) {
290                                    setEventThreadIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
291                                                                                     "setAllowedEventThread_OnlyRunBy",
292                                                                                     "(Z)V");
293                                }
294                                setEventThreadCallInstr.setReference(setEventThreadIndex);
295                                il.insertInstr(setEventThreadCallInstr, mi.getCodeAttributeInfo());
296                                
297                                if(methodAR.allowEventThread==OnlyRunBy.EVENT_THREAD.ONLY_AFTER_REALIZED) {
298                                    if (isDisplayableCallIndex==0) {
299                                        isDisplayableCallIndex = cf.addMethodToConstantPool("java/awt/Component",
300                                                                                            "isDisplayable",
301                                                                                            "()Z");
302                                    }
303                                    isDisplayableCallInstr.setReference(isDisplayableCallIndex);
304                                    il.insertInstr(isDisplayableCallInstr, mi.getCodeAttributeInfo());
305                                    il.insertInstr(new GenericInstruction(Opcode.ALOAD_0), mi.getCodeAttributeInfo());
306                                }
307                                else {
308                                    il.insertInstr(new GenericInstruction(Opcode.ICONST_1), mi.getCodeAttributeInfo());
309                                }
310                            }
311                        }
312                        
313                        // add checks for thread groups from OnlyRunBy
314                        for(String s: methodAR.allowThreadNames) {
315                            changed = true;
316                            
317                            if (addNameCallIndex==0) {
318                                addNameCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
319                                                                              "addAllowedName_OnlyRunBy",
320                                                                              "(Ljava/lang/String;)V");
321                            }
322                            addNameCallInstr.setReference(addNameCallIndex);
323                            il.insertInstr(addNameCallInstr, mi.getCodeAttributeInfo());
324                            
325                            // add the string
326                            AUTFPoolInfo utfpi = new ASCIIPoolInfo(s, cp);
327                            int[] l = cf.addConstantPoolItems(new APoolInfo[]{utfpi});
328                            utfpi = cf.getConstantPoolItem(l[0]).execute(CheckUTFVisitor.singleton(), null);
329                            
330                            StringPoolInfo spi = new StringPoolInfo(utfpi, cp);
331                            l = cf.addConstantPoolItems(new APoolInfo[]{spi});
332                            spi = cf.getConstantPoolItem(l[0]).execute(new ADefaultPoolInfoVisitor<StringPoolInfo, Object>() {
333                                public StringPoolInfo defaultCase(APoolInfo host, Object o) {
334                                    throw new ClassFormatError("Info is of type " + host.getClass().getName() + ", needs to be StringPoolInfo");
335                                }
336                                public StringPoolInfo stringCase(StringPoolInfo host, Object o) {
337                                    return host;
338                                }
339                            }, null);
340                            
341                            loadStringInstr.setReference(l[0]);
342                            il.insertInstr(loadStringInstr, mi.getCodeAttributeInfo());
343                            
344                            // now we have inserted the following bytecode:
345                            // ldc_w (string)
346                            // invokestatic (addAllowedName_OnlyRunBy)
347                        }
348                        
349                        // add checks for thread ids from OnlyRunBy
350                        for(long id: methodAR.allowThreadIds) {
351                            changed = true;
352                            
353                            if (addIdCallIndex==0) {
354                                addIdCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
355                                                                            "addAllowedId_OnlyRunBy",
356                                                                            "(J)V");
357                            }
358                            addIdCallInstr.setReference(addIdCallIndex);
359                            il.insertInstr(addIdCallInstr, mi.getCodeAttributeInfo());
360                            
361                            // add the long
362                            loadLongInstr.setReference(cf.addLongToConstantPool(id));
363                            
364                            il.insertInstr(loadLongInstr, mi.getCodeAttributeInfo());
365                            
366                            // now we have inserted the following bytecode:
367                            // ldc2_w (id)
368                            // invokestatic (addAllowedId_OnlyRunBy)
369                        }
370                        
371                        // add checks for thread groups from OnlyRunBy
372                        for(String s: methodAR.allowThreadGroups) {
373                            changed = true;
374                            
375                            if (addGroupCallIndex==0) {
376                                addGroupCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
377                                                                               "addAllowedGroup_OnlyRunBy",
378                                                                               "(Ljava/lang/String;)V");
379                            }
380                            addGroupCallInstr.setReference(addGroupCallIndex);
381                            il.insertInstr(addGroupCallInstr, mi.getCodeAttributeInfo());
382                            
383                            // add the string
384                            AUTFPoolInfo utfpi = new ASCIIPoolInfo(s, cp);
385                            int[] l = cf.addConstantPoolItems(new APoolInfo[]{utfpi});
386                            utfpi = cf.getConstantPoolItem(l[0]).execute(CheckUTFVisitor.singleton(), null);
387                            
388                            StringPoolInfo spi = new StringPoolInfo(utfpi, cp);
389                            l = cf.addConstantPoolItems(new APoolInfo[]{spi});
390                            spi = cf.getConstantPoolItem(l[0]).execute(new ADefaultPoolInfoVisitor<StringPoolInfo, Object>() {
391                                public StringPoolInfo defaultCase(APoolInfo host, Object o) {
392                                    throw new ClassFormatError("Info is of type " + host.getClass().getName() + ", needs to be StringPoolInfo");
393                                }
394                                public StringPoolInfo stringCase(StringPoolInfo host, Object o) {
395                                    return host;
396                                }
397                            }, null);
398                            
399                            loadStringInstr.setReference(l[0]);
400                            il.insertInstr(loadStringInstr, mi.getCodeAttributeInfo());
401                            
402                            // now we have inserted the following bytecode:
403                            // ldc_w (string)
404                            // invokestatic (addAllowedName_OnlyRunBy)
405                        }
406                        
407                        int lengthAfterOnlyRunBy = il.getLength();
408                        if (lengthAfterOnlyRunBy>lengthBeforeOnlyRunBy) {
409                            // add check call for OnlyRunBy
410                            if (okCheckCallIndex==0) {
411                                okCheckCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
412                                                                              "checkCurrentThread_OnlyRunBy",
413                                                                              "()V");
414                            }
415                            okCheckCallInstr.setReference(okCheckCallIndex);
416                            il.advanceIndex(lengthAfterOnlyRunBy-lengthBeforeOnlyRunBy);
417                            il.insertInstr(okCheckCallInstr, mi.getCodeAttributeInfo());
418                            // invokestatic (checkCurrentThread_OnlyRunBy)
419                        }
420    
421                        if (changed) {
422                            // write code back
423                            mi.getCodeAttributeInfo().setCode(il.getCode());
424                            
425                            // make sure we have at least five slots on the stack
426                            CodeAttributeInfo.CodeProperties cProps = mi.getCodeAttributeInfo().getProperties();
427                            cProps.maxStack = Math.max(5, cProps.maxStack);
428                            mi.getCodeAttributeInfo().setProperties(cProps.maxStack, cProps.maxLocals);
429                        }
430                    }
431                }
432            }
433        }
434    }