001    /*
002      Copyright (C) 2001-2003 Laurent Martelli <laurent@aopsys.com>
003    
004      This program is free software; you can redistribute it and/or modify
005      it under the terms of the GNU Lesser General Public License as
006      published by the Free Software Foundation; either version 2 of the
007      License, or (at your option) any later version.
008    
009      This program is distributed in the hope that it will be useful,
010      but WITHOUT ANY WARRANTY; without even the implied warranty of
011      MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
012      GNU Lesser General Public License for more details.
013    
014      You should have received a copy of the GNU Lesser General Public License
015      along with this program; if not, write to the Free Software
016      Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA */
017    
018    package org.objectweb.jac.aspects.persistence;
019    
import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Vector;
import org.apache.log4j.Logger;
import org.objectweb.jac.core.rtti.ClassItem;
import org.objectweb.jac.core.rtti.CollectionItem;
import org.objectweb.jac.core.rtti.FieldItem;
import org.objectweb.jac.util.Strings;
036    
037    /**
038     * Implements the storage to store within an SQL compliant database system.
039     *
040     * @see LongOID
041     */
042    
043    public abstract class SQLStorage implements Storage,java.io.Serializable {
044        static Logger logger = Logger.getLogger("persistence.storage");
045        static Logger loggerSql = Logger.getLogger("persistence.sql");
046    
047        /**
048         * The SQL connection to the database that is use by this storage. */
049        protected Connection db;
050    
051        /**
052         * Default constructor. */
053        protected SQLStorage(PersistenceAC ac) throws SQLException {
054            this. ac = ac;
055        }
056    
057        /**
058         * Creates a new SQL storage.<p>
059         *
060         * @param db the connection to the database
061         */
062        public SQLStorage(PersistenceAC ac, Connection db) throws SQLException {
063            this. ac = ac;
064            setConnection(db);
065        }
066    
067        protected String id;
068        public String getId() {
069            return id;
070        }
071        public void setId(String id) {
072            this.id = id;
073        }
074    
075        public void close() {}
076    
077        /**
078         * Resets the connection to the database with the given
079         * connection.<p>
080         * 
081         * @param db the new connection */
082    
083        protected void setConnection(Connection db) throws SQLException {
084            this.db = db;
085            updateJacNames();
086        }
087    
088        /**
089         * Tells wether a table with a given name exists
090         */
091        protected abstract boolean hasTable(String name) throws Exception;
092    
093        /**
094         * Updates jac names from <classname><count> to <classname>#<count>
095         */
096        protected void updateJacNames() throws SQLException {
097            ResultSet rs = executeQuery(
098                "SELECT roots.id, roots.name, classes.classid from roots,classes where "+
099                "not roots.name like '%#%' and roots.id=classes.id");
100            execute("BEGIN TRANSACTION");
101            try {
102                while (rs.next()) {
103                    String classname = Strings.getShortClassName(rs.getString("classid")).toLowerCase();
104                    String name = rs.getString("name");
105                    if (name.startsWith(classname) && 
106                        name.length()>classname.length() &&
107                        name.charAt(classname.length())!='#') 
108                    {
109                        String newName = classname+"#"+name.substring(classname.length());
110                        executeUpdate("update roots set name='"+newName+"' where name='"+name+"'");
111                    }
112                }
113                execute("COMMIT");
114            } catch (Exception e) {
115                execute("ROLLBACK");
116                logger.error("Failed to update jac names");
117            }
118        }    
119    
120        protected int executeUpdate(String query) throws SQLException {
121            try {
122                loggerSql.debug(query);
123                return db.createStatement().executeUpdate(query);
124            } catch (SQLException e) {
125                logger.error("executeUpdate query failed: "+query); 
126                throw e;
127            }
128        }
129    
130        protected ResultSet executeQuery(String query) throws SQLException {
131            try {
132                loggerSql.debug(query);
133                return db.createStatement().executeQuery(query);
134            } catch (SQLException e) {
135                logger.error("executeQuery query failed: "+query);
136                throw e;
137            }
138        }
139    
140        protected boolean execute(String query) throws SQLException {
141            try {
142                loggerSql.debug(query);
143                return db.createStatement().execute(query);
144            } catch (SQLException e) {
145                logger.error("execute query failed: "+query);
146                throw e;
147            }
148        }
149    
150        protected boolean executeSilent(String query) throws SQLException {
151            return db.createStatement().execute(query);
152        }
153    
154        public void deleteObject(OID oid) throws Exception
155        {
156            logger.debug("deleteObject("+oid+")");
157            executeUpdate("delete from objects where id="+oid.localId());
158            executeUpdate("delete from roots where id="+oid.localId());
159            // WE SHOULD ALSO REMOVE COLLECTIONS
160        }
161    
162        public void setField(OID oid, FieldItem field, Object object) 
163            throws Exception
164        {
165            logger.debug("setField("+oid+","+field.getName()+","+object+")");
166            String value = ValueConverter.objectToString(this,object);
167            String fieldID = field.getName();
168            String query = "insert into objects (id,fieldID,value) values "+
169                "("+oid.localId()+",'"+fieldID+"','"+addSlashes(value)+"')";
170            if (executeUpdate(query)==0) {
171                logger.error("setField failed : "+oid+","+fieldID+","+value);
172            }
173        }
174    
175        public void updateField(OID oid, FieldItem field, Object object) 
176            throws Exception
177        {
178            logger.debug("updateField("+oid+","+field+","+object+")");
179            String fieldID = field.getName();
180            String value = ValueConverter.objectToString(this,object);
181            String query = "update objects set value='"+addSlashes(value)+"' where "+
182                "id="+oid.localId()+" and fieldID='"+fieldID+"'";
183            if (executeUpdate(query)==0) {
184                setField(oid,field,object);
185            }
186        }
187    
188        public Object getField(OID oid, FieldItem field) 
189            throws Exception
190        {
191            logger.debug("getField("+oid+","+field.getName()+")");
192            checkStorage();
193            String fieldID = field.getName();
194            ResultSet rs = executeQuery("select value from objects where id="+oid.localId()+
195                                        " and fieldID='"+fieldID+"'");
196            if (rs.next()) {
197                return ValueConverter.stringToObject(this,rs.getString("value"));
198            } else {
199                if (field.isPrimitive()) {
200                    logger.warn("no such field in storage "+oid+","+fieldID);
201                }
202                return null;
203            }
204        }
205    
206        public StorageField[] getFields(OID oid, ClassItem cl, FieldItem[] fields) 
207            throws Exception
208        {
209            logger.debug("getFields "+oid+","+cl+","+Arrays.asList(fields));
210            // compute stringified list of fields for SQL query
211            if (fields.length == 0) {
212                return new StorageField[0];
213            }
214            String fieldlist = "(";
215            boolean first = true;
216            for (int i=0; i<fields.length; i++) {
217                if (!fields[i].isCalculated() && !fields[i].isTransient()) {
218                    if (!first)
219                        fieldlist += " or ";
220                    fieldlist += "fieldID='"+fields[i].getName()+"'";
221                    first = false;
222                }
223            }
224            fieldlist += ")";
225          
226            StorageField fieldValues[] = new StorageField[fields.length];
227            String query = "select * from objects where id="+oid.localId()+" and "+fieldlist;
228            ResultSet rs = executeQuery(query);
229    
230            int i=0;
231            while (rs.next()) {
232                FieldItem field = cl.getField(rs.getString("fieldID"));
233                fieldValues[i] = new StorageField(
234                    cl,field,
235                    ValueConverter.stringToObject(this,rs.getString("value")));
236                i++;
237            }
238            return fieldValues;
239        }
240    
241        public void removeField(OID oid, FieldItem field, Object value) 
242            throws Exception
243        {
244            logger.debug("removeField("+oid+","+field+","+value+")");
245            String fieldID = field.getName();
246            executeUpdate("delete from objects where id="+oid.localId()+
247                          " and fieldID='"+fieldID+"'");      
248        }
249    
250        public Collection getRootObjects() throws Exception {
251            logger.debug("getRootObjects");
252            String sql = "select id from roots";
253            ResultSet rs = executeQuery(sql);
254            Vector result = new Vector();
255            while (rs.next()) {
256                result.add(new LongOID(this,rs.getLong("value")));
257            }
258            logger.debug("getRootObjects returns " + result);
259            return result;
260        }
261    
262        // Collection methods
263    
264        public OID getCollectionID(OID oid, CollectionItem collection) 
265            throws Exception 
266        {
267            return getOID("select value from objects where "+
268                          "id="+oid.localId()+" and fieldID='"+collection.getName()+"'");
269        }
270    
271        public List getCollectionValues(OID oid, CollectionItem collection, 
272                                        String table, String orderBy) 
273            throws Exception
274        {
275            logger.debug("getCollectionValues("+oid+","+collection+")");
276            String fieldID = collection.getName();
277    
278            String sql = "select "+table+".value from "+table+",objects where "+
279                "objects.id="+oid.localId()+
280                " and objects.fieldID='"+fieldID+"'"+
281                " and objects.value="+table+".id";
282            if ( orderBy!=null) {
283                sql += " order by " + orderBy;
284            }
285          
286            ResultSet rs = executeQuery(sql);
287            Vector result = new Vector();
288            while (rs.next()) {
289                result.add(ValueConverter.stringToObject(this,rs.getString("value")));
290            }
291            logger.debug("getCollectionValues returns " + result);
292            return result;
293        }
294    
295        public boolean collectionContains(String table, OID cid, Object value) 
296            throws Exception 
297        {
298            ResultSet res = executeQuery(
299                "select id from "+table+" where "+
300                "id="+cid.localId()+" and value='"+
301                addSlashes(ValueConverter.objectToString(this,value))+"'");
302            return res.next();
303        }
304    
305        // List methods
306    
307        public void clearList(OID cid) 
308            throws Exception 
309        {
310            logger.debug("clearList("+cid+")");
311            executeUpdate("delete from lists where id="+cid.localId());
312        }
313    
314        public List getList(OID oid, CollectionItem collection) 
315            throws Exception
316        {
317            return getList(getCollectionID(oid,collection));
318        }
319    
320        public List getList(OID cid)
321            throws Exception
322        {
323            logger.debug("getList("+cid+")");
324            ResultSet rs = executeQuery("select value from lists "+
325                                        "where id="+cid.localId()+" order by index");
326            Vector result = new Vector();
327            while (rs.next()) {
328                result.add(
329                    ValueConverter.stringToObject(this,rs.getString("value")));
330            }
331            logger.debug("getList returns " + result);
332            return result;
333        }
334    
335        public long getListSize(OID cid) 
336            throws Exception
337        {
338            return getLong("select count(*) from lists where id="+cid.localId());
339        }
340    
341        public boolean listContains(OID cid, Object value) 
342            throws Exception 
343        {
344            return collectionContains("lists",cid,value);
345        }
346    
347        public Object getListItem(OID cid, long index)
348            throws Exception
349        {
350            ResultSet rs = 
351                executeQuery("select value from lists where "+
352                             "id="+cid.localId()+" order by index limit 1 offset "+index);
353            if (rs.next()) {
354                return ValueConverter.stringToObject(this,rs.getString("value"));
355            } else {
356                return null;      
357            }
358        }
359    
360        public long getIndexInList(OID cid, Object value)
361            throws Exception
362        {
363            ResultSet rs = executeQuery(
364                "select min(index) as index from lists where "+
365                "id="+cid.localId()+" and value='"+
366                addSlashes(ValueConverter.objectToString(this,value))+"'");
367            if (rs.next()) {
368                long index = rs.getLong(1);
369                return getLong("select count(*) from lists where id="+cid.localId()+
370                               " and index<="+index)-1;
371            } else {
372                return -1;
373            }
374        }
375    
376        protected long getInternalIndexInList(OID cid, Object value)
377            throws Exception
378        {
379            ResultSet rs = executeQuery(
380                "select min(index) as index from lists where "+
381                "id="+cid.localId()+" and value='"+
382                addSlashes(ValueConverter.objectToString(this,value))+"'");
383            if (rs.next()) {
384                long result = rs.getLong(1);
385                if (rs.wasNull())
386                    return -1;
387                else
388                    return result;
389            } else {
390                return -1;
391            }
392        }
393    
394        public long getLastIndexInList(OID cid, Object value)
395            throws Exception
396        {
397            ResultSet rs = executeQuery(
398                "select max(index) from lists where "+
399                "id="+cid.localId()+" and value='"+
400                addSlashes(ValueConverter.objectToString(this,value))+"'");
401            if (rs.next()) {
402                long index = rs.getLong(1);
403                return getLong("select count(*) from lists where id="+cid.localId()+
404                               " and index<="+index)-1;
405            } else {
406                return -1;      
407            }
408        }
409    
410        public void addToList(OID cid, long position, Object value)
411            throws Exception
412        {
413            logger.debug("addToList("+cid+","+position+","+value+")");
414            executeUpdate("update lists set index=index+1 where id="+cid.localId()+
415                          " and index>="+position);
416            executeUpdate(
417                "insert into lists (id,index,value) values "+
418                "("+cid.localId()+","+position+",'"+
419                addSlashes(ValueConverter.objectToString(this,value))+"')");
420        }
421    
422        public void addToList(OID cid, Object value)
423            throws Exception
424        {
425            logger.debug("addToList("+cid+","+value+")");
426            long size = getListSize(cid);
427            String indexExpr;
428            if (size==0)
429                indexExpr = "0";
430            else
431                indexExpr = "select max(index)+1 from lists where id="+cid.localId();
432            executeUpdate(
433                "insert into lists (id,index,value) values "+
434                "("+cid.localId()+",("+indexExpr+"),'"+
435                addSlashes(ValueConverter.objectToString(this,value))+"')");
436        }
437    
438        public void setListItem(OID cid, long index, Object value)
439            throws Exception
440        {
441            logger.debug("setListItem("+cid+","+index+","+value+")");
442            executeUpdate("update lists set value='"+
443                          addSlashes(ValueConverter.objectToString(this,value))+
444                          " where id="+cid.localId()+" and index="+index);
445        }
446    
447        public void removeFromList(OID cid, long position)
448            throws Exception
449        {
450            logger.debug("removeFromList("+cid+","+position+")");
451            // First, get the index for the position
452            ResultSet rs = executeQuery("select index from lists where "+
453                                        "id="+cid.localId()+" order by index limit 1,"+position);
454            long index = rs.getLong("index");
455            executeUpdate("delete from lists where "+"id="+cid.localId()+" and index="+index);
456        }
457    
458        public void removeFromList(OID cid, Object value)
459            throws Exception
460        {
461            logger.debug("removeFromList("+cid+","+value+")");
462            long index = getInternalIndexInList(cid,value);
463            executeUpdate("delete from lists where "+"id="+cid.localId()+" and index="+index);
464        }
465    
466        // Set methods
467    
468        public void clearSet(OID cid) 
469            throws Exception 
470        {
471            logger.debug("clearSet("+cid+")");
472            executeUpdate("delete from sets where id="+cid.localId());
473        }
474    
475        public List getSet(OID oid, CollectionItem collection) throws Exception {
476            return getSet(getCollectionID(oid,collection));
477        }   
478    
479        public List getSet(OID cid) throws Exception {
480            logger.debug("getSet("+cid+")");
481            ResultSet rs = executeQuery("select value from sets "+
482                                        "where id="+cid.localId());
483            Vector result = new Vector();
484            while (rs.next()) {
485                result.add(
486                    ValueConverter.stringToObject(this,rs.getString("value")));
487            }
488            logger.debug("getSet returns " + result);
489            return result;
490        }   
491    
492        public long getSetSize(OID cid)
493            throws Exception
494        {
495            return getLong("select count(*) from sets where id="+cid.localId());
496        }
497    
498        public boolean setContains(OID cid, Object value) 
499            throws Exception 
500        {
501            return collectionContains("sets",cid,value);
502        }
503    
504        public boolean addToSet(OID cid, Object value) 
505            throws Exception 
506        {
507            logger.debug("addToSet("+cid+","+value+")");
508            if (!collectionContains("sets",cid,value)) {
509                executeUpdate(
510                    "insert into sets (id,value) values "+"("+cid.localId()+",'"+
511                    addSlashes(ValueConverter.objectToString(this,value))+"')");
512                return true;
513            } else {
514                return false;
515            }
516        }
517    
518        public boolean removeFromSet(OID cid, Object value) 
519            throws Exception 
520        {
521            logger.debug("removeFromSet("+cid+","+value+")");
522            boolean result = collectionContains("set",cid,value);
523            if (result)
524                executeUpdate(
525                    "delete from sets where "+
526                    "id="+cid.localId()+" and value='"+
527                    addSlashes(ValueConverter.objectToString(this,value))+"'");
528            return result;
529        }
530    
531        // Map functions
532    
533        public void clearMap(OID cid) 
534            throws Exception 
535        {
536            logger.debug("clearMap("+cid+")");
537            executeUpdate("delete from maps where id="+cid.localId());
538        }
539    
540        public Map getMap(OID oid, CollectionItem collection) 
541            throws Exception
542        {
543            return getMap(getCollectionID(oid,collection));
544        }
545    
546        public Map getMap(OID cid) throws Exception
547        {
548            logger.debug("getMap("+cid+")");
549            ResultSet rs =
550                executeQuery("select value,key from maps where id="+cid.localId());
551            Map result = new HashMap();
552            while (rs.next()) {
553                result.put(
554                    ValueConverter.stringToObject(this,rs.getString("key")),
555                    ValueConverter.stringToObject(this,rs.getString("value")));
556            }
557            logger.debug("getMap returns " + result);
558            return result;
559        }
560    
561        public long getMapSize(OID cid)
562            throws Exception
563        {
564            return getLong("select count(*) from maps where id="+cid.localId());
565        }
566    
567        public Object putInMap(OID cid, Object key, Object value) 
568            throws Exception
569        {
570            logger.debug("putInMap("+cid+","+key+"->"+value+")");
571            if (mapContainsKey(cid,key)) {
572                Object old = getFromMap(cid,key);
573                executeUpdate(
574                    "update maps set "+
575                    "key='"+addSlashes(ValueConverter.objectToString(this,key))+"',"+
576                    "value='"+addSlashes(ValueConverter.objectToString(this,value))+"' "+
577                    "where id="+cid.localId()+
578                    " and key='"+addSlashes(ValueConverter.objectToString(this,key))+"'");
579                return old;
580            } else {
581                executeUpdate(
582                    "insert into maps (id,key,value) values "+
583                    "("+cid.localId()+",'"+
584                    addSlashes(ValueConverter.objectToString(this,key))+
585                    "','"+addSlashes(ValueConverter.objectToString(this,value))+"')");
586                return null;
587            }
588        }
589    
590        public Object getFromMap(OID cid, Object key) 
591            throws Exception
592        {
593            logger.debug("getFromMap("+cid+","+key+")");
594            ResultSet res = executeQuery(
595                "select value from maps where "+
596                "id="+cid.localId()+" and key='"+
597                addSlashes(ValueConverter.objectToString(this,key))+"'");
598            if (res.next()) {
599                return ValueConverter.stringToObject(this,res.getString("value"));
600            } else {
601                return null;
602            }
603        }
604    
605        public boolean mapContainsKey(OID cid, Object key) 
606            throws Exception
607        {
608            logger.debug("mapContainsKey("+cid+","+key+")");
609            ResultSet res = executeQuery(
610                "select value from maps where "+
611                "id="+cid.localId()+" and key='"+
612                addSlashes(ValueConverter.objectToString(this,key))+"'");
613            return res.next();
614        }
615    
616        public boolean mapContainsValue(OID cid, Object value) 
617            throws Exception 
618        {
619            return collectionContains("maps",cid,value);
620        }
621    
622        public Object removeFromMap(OID cid, Object key)
623            throws Exception
624    
625        {
626            logger.debug("removeFromMap("+cid+","+key+")");
627            if (!mapContainsKey(cid,key)) {
628                return null;
629            } else {
630                Object result = getFromMap(cid,key);
631                executeUpdate(
632                    "delete from maps where "+
633                    "id="+cid.localId()+" and key='"+
634                    addSlashes(ValueConverter.objectToString(this,key))+"'");
635                return result;
636            }
637        }
638    
639        public abstract String newName(String className) throws Exception;
640    
641    
642        public Map getNameCounters() {
643            // TODO
644            return new Hashtable();
645        }
646    
647        public abstract void updateNameCounters(Map counters) throws Exception;
648    
649        public OID getOIDFromName(String name) 
650            throws Exception
651        {
652            ResultSet rs = executeQuery(
653                "select id from roots where name='"+name+"'");
654            if (rs.next()) {
655                return new LongOID(this,rs.getLong("id"));
656            } else {
657                return null;
658            }
659        }
660       
661        public String getNameFromOID(OID oid) throws Exception
662        {
663            ResultSet rs = executeQuery(
664                "select name from roots where id="+oid.localId());
665            if (rs.next()) {
666                return rs.getString("name");
667            } else {
668                return null;
669            }
670        }
671    
672        public void bindOIDToName(OID oid, String name) throws Exception
673        {
674            logger.debug("bindOIDToName "+oid+" -> "+name);
675            executeUpdate("insert into roots (id,name) values ("+
676                          oid.localId()+",'"+name+"')");
677        }
678    
679        public void deleteName(String name) throws Exception
680        {
681            logger.debug("deleteName("+name+")");
682            executeUpdate("delete from roots where name='"+name+"'");
683        }
684    
685        public String getClassID(OID oid) throws Exception
686        {
687            ResultSet rs = executeQuery("select classid from classes where id="+oid.localId());
688            if (rs.next()) {
689                String classID = rs.getString("classid");
690                logger.debug("getClassID("+oid+") -> "+classID);
691                return classID;
692            } else {
693                throw new NoSuchOIDError(oid);
694            }
695        }
696    
697        public Collection getObjects(ClassItem cl) throws Exception
698        {
699            logger.debug("getObjects("+cl.getName()+")");
700            Vector result = new Vector();
701            getObjects(cl,result);
702            return result;
703        }
704    
705        protected void getObjects(ClassItem cl, Vector objects) throws SQLException {
706            String query = "select id from classes";
707            if (cl != null) {
708                query += " where classes.classid='"+cl.getName()+"'";
709            }
710                   
711            ResultSet rs = executeQuery(query);
712            while (rs.next()) {
713                objects.add(new LongOID(this,rs.getLong("id")));
714            }
715    
716            Iterator i = cl.getChildren().iterator();
717            while(i.hasNext()) {
718                ClassItem subclass = (ClassItem)i.next();
719                getObjects(subclass,objects);
720            }
721        }
722    
723        public void startTransaction() throws SQLException {
724            execute("BEGIN TRANSACTION");
725        }
726    
727        public void commit() throws SQLException {
728            execute("COMMIT");
729        }
730    
731        public void rollback() throws SQLException {
732            execute("ROLLBACK");
733        }
734    
735        /**
736         * Throw an exception if storage is null or invalid
737         *
738         */
739        protected void checkStorage() {
740            if (db==null) {
741                logger.error("connection is NULL");
742                throw new InvalidStorageException("connection is NULL");
743            }
744        }
745    
746        /**
747         * Creates a new object using a PostgreSQL sequance.<p>
748         *
749         * @return the new object OID
750         */
751        public OID createObject(String className) throws Exception {
752            LongOID res = new LongOID(this,getNextVal("object_id"));
753            executeUpdate("insert into classes (id,classid) values ("+
754                          res.localId()+",'"+className+"')");
755            return res;
756        }
757    
758        /**
759         * Returns the next value of a sequence
760         */
761        public abstract long getNextVal(String sequence) throws Exception;
762    
763        public long getLong(String query) throws Exception {
764            ResultSet rs = executeQuery(query);
765            rs.next();
766            return rs.getLong(1);
767        }
768    
769        public OID getOID(String query) throws Exception {
770            return new LongOID(this,getLong(query));
771        }
772    
773        public static class InvalidStorageException extends RuntimeException {
774            public InvalidStorageException(String msg) { 
775                super(msg); 
776            }
777        }
778    
779        /* add '\' before ''' and '\' */
780        public static String addSlashes(String str) {
781            StringBuffer res = new StringBuffer(str.length());
782            for (int i=0; i<str.length();i++) {
783                if (str.charAt(i)=='\'' || str.charAt(i)=='\\') {
784                    res.append('\\');
785                }
786                res.append(str.charAt(i));
787            }
788            return res.toString();
789        }
790    
791        PersistenceAC ac;
792    }