001/*
002 * HA-JDBC: High-Availability JDBC
003 * Copyright (c) 2004-2007 Paul Ferraro
004 * 
005 * This library is free software; you can redistribute it and/or modify it 
006 * under the terms of the GNU Lesser General Public License as published by the 
007 * Free Software Foundation; either version 2.1 of the License, or (at your 
008 * option) any later version.
009 * 
010 * This library is distributed in the hope that it will be useful, but WITHOUT
011 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 
012 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License 
013 * for more details.
014 * 
015 * You should have received a copy of the GNU Lesser General Public License
016 * along with this library; if not, write to the Free Software Foundation, 
017 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
018 * 
019 * Contact: ferraro@users.sourceforge.net
020 */
021package net.sf.hajdbc.sync;
022
023import java.sql.Connection;
024import java.sql.ResultSet;
025import java.sql.SQLException;
026import java.sql.Statement;
027import java.sql.Types;
028import java.text.MessageFormat;
029import java.util.Collection;
030import java.util.HashMap;
031import java.util.Map;
032import java.util.Set;
033import java.util.concurrent.Callable;
034import java.util.concurrent.ExecutionException;
035import java.util.concurrent.ExecutorService;
036import java.util.concurrent.Future;
037
038import net.sf.hajdbc.Database;
039import net.sf.hajdbc.Dialect;
040import net.sf.hajdbc.ForeignKeyConstraint;
041import net.sf.hajdbc.Messages;
042import net.sf.hajdbc.SequenceProperties;
043import net.sf.hajdbc.SynchronizationContext;
044import net.sf.hajdbc.TableProperties;
045import net.sf.hajdbc.UniqueConstraint;
046import net.sf.hajdbc.util.SQLExceptionFactory;
047import net.sf.hajdbc.util.Strings;
048
049import org.slf4j.Logger;
050import org.slf4j.LoggerFactory;
051
052/**
053 * @author Paul Ferraro
054 *
055 */
056public final class SynchronizationSupport
057{
058        private static Logger logger = LoggerFactory.getLogger(SynchronizationSupport.class);
059        
060        private SynchronizationSupport()
061        {
062                // Hide
063        }
064        
065        /**
066         * Drop all foreign key constraints on the target database
067         * @param <D> 
068         * @param context a synchronization context
069         * @throws SQLException if database error occurs
070         */
071        public static <D> void dropForeignKeys(SynchronizationContext<D> context) throws SQLException
072        {
073                Dialect dialect = context.getDialect();
074                
075                Connection connection = context.getConnection(context.getTargetDatabase());
076                
077                Statement statement = connection.createStatement();
078                
079                for (TableProperties table: context.getTargetDatabaseProperties().getTables())
080                {
081                        for (ForeignKeyConstraint constraint: table.getForeignKeyConstraints())
082                        {
083                                String sql = dialect.getDropForeignKeyConstraintSQL(constraint);
084                                
085                                logger.debug(sql);
086                                
087                                statement.addBatch(sql);
088                        }
089                }
090                
091                statement.executeBatch();
092                statement.close();
093        }
094        
095        /**
096         * Restores all foreign key constraints on the target database
097         * @param <D> 
098         * @param context a synchronization context
099         * @throws SQLException if database error occurs
100         */
101        public static <D> void restoreForeignKeys(SynchronizationContext<D> context) throws SQLException
102        {
103                Dialect dialect = context.getDialect();
104                
105                Connection connection = context.getConnection(context.getTargetDatabase());
106                
107                Statement statement = connection.createStatement();
108                
109                for (TableProperties table: context.getSourceDatabaseProperties().getTables())
110                {
111                        for (ForeignKeyConstraint constraint: table.getForeignKeyConstraints())
112                        {
113                                String sql = dialect.getCreateForeignKeyConstraintSQL(constraint);
114                                
115                                logger.debug(sql);
116                                
117                                statement.addBatch(sql);
118                        }
119                }
120                
121                statement.executeBatch();
122                statement.close();
123        }
124        
125        /**
126         * Synchronizes the sequences on the target database with the source database.
127         * @param <D> 
128         * @param context a synchronization context
129         * @throws SQLException if database error occurs
130         */
131        public static <D> void synchronizeSequences(final SynchronizationContext<D> context) throws SQLException
132        {
133                Collection<SequenceProperties> sequences = context.getSourceDatabaseProperties().getSequences();
134
135                if (!sequences.isEmpty())
136                {
137                        Database<D> sourceDatabase = context.getSourceDatabase();
138                        
139                        Set<Database<D>> databases = context.getActiveDatabaseSet();
140
141                        ExecutorService executor = context.getExecutor();
142                        
143                        Dialect dialect = context.getDialect();
144                        
145                        Map<SequenceProperties, Long> sequenceMap = new HashMap<SequenceProperties, Long>();
146                        Map<Database<D>, Future<Long>> futureMap = new HashMap<Database<D>, Future<Long>>();
147
148                        for (SequenceProperties sequence: sequences)
149                        {
150                                final String sql = dialect.getNextSequenceValueSQL(sequence);
151                                
152                                logger.debug(sql);
153
154                                for (final Database<D> database: databases)
155                                {
156                                        Callable<Long> task = new Callable<Long>()
157                                        {
158                                                public Long call() throws SQLException
159                                                {
160                                                        Statement statement = context.getConnection(database).createStatement();
161                                                        ResultSet resultSet = statement.executeQuery(sql);
162                                                        
163                                                        resultSet.next();
164                                                        
165                                                        long value = resultSet.getLong(1);
166                                                        
167                                                        statement.close();
168                                                        
169                                                        return value;
170                                                }
171                                        };
172                                        
173                                        futureMap.put(database, executor.submit(task));                         
174                                }
175
176                                try
177                                {
178                                        Long sourceValue = futureMap.get(sourceDatabase).get();
179                                        
180                                        sequenceMap.put(sequence, sourceValue);
181                                        
182                                        for (Database<D> database: databases)
183                                        {
184                                                if (!database.equals(sourceDatabase))
185                                                {
186                                                        Long value = futureMap.get(database).get();
187                                                        
188                                                        if (!value.equals(sourceValue))
189                                                        {
190                                                                throw new SQLException(Messages.getMessage(Messages.SEQUENCE_OUT_OF_SYNC, sequence, database, value, sourceDatabase, sourceValue));
191                                                        }
192                                                }
193                                        }
194                                }
195                                catch (InterruptedException e)
196                                {
197                                        throw SQLExceptionFactory.createSQLException(e);
198                                }
199                                catch (ExecutionException e)
200                                {
201                                        throw SQLExceptionFactory.createSQLException(e.getCause());
202                                }
203                        }
204                        
205                        Connection targetConnection = context.getConnection(context.getTargetDatabase());
206                        Statement targetStatement = targetConnection.createStatement();
207
208                        for (SequenceProperties sequence: sequences)
209                        {
210                                String sql = dialect.getAlterSequenceSQL(sequence, sequenceMap.get(sequence) + 1);
211                                
212                                logger.debug(sql);
213                                
214                                targetStatement.addBatch(sql);
215                        }
216                        
217                        targetStatement.executeBatch();         
218                        targetStatement.close();
219                }
220        }
221        
222        /**
223         * @param <D>
224         * @param context
225         * @throws SQLException
226         */
227        public static <D> void synchronizeIdentityColumns(SynchronizationContext<D> context) throws SQLException
228        {
229                Statement sourceStatement = context.getConnection(context.getSourceDatabase()).createStatement();
230                Statement targetStatement = context.getConnection(context.getTargetDatabase()).createStatement();
231                
232                Dialect dialect = context.getDialect();
233                
234                for (TableProperties table: context.getSourceDatabaseProperties().getTables())
235                {
236                        Collection<String> columns = table.getIdentityColumns();
237                        
238                        if (!columns.isEmpty())
239                        {
240                                String selectSQL = MessageFormat.format("SELECT max({0}) FROM {1}", Strings.join(columns, "), max("), table.getName()); //$NON-NLS-1$ //$NON-NLS-2$
241                                
242                                logger.debug(selectSQL);
243                                
244                                Map<String, Long> map = new HashMap<String, Long>();
245                                
246                                ResultSet resultSet = sourceStatement.executeQuery(selectSQL);
247                                
248                                if (resultSet.next())
249                                {
250                                        int i = 0;
251                                        
252                                        for (String column: columns)
253                                        {
254                                                map.put(column, resultSet.getLong(++i));
255                                        }
256                                }
257                                
258                                resultSet.close();
259                                
260                                if (!map.isEmpty())
261                                {
262                                        for (Map.Entry<String, Long> mapEntry: map.entrySet())
263                                        {
264                                                String alterSQL = dialect.getAlterIdentityColumnSQL(table, table.getColumnProperties(mapEntry.getKey()), mapEntry.getValue() + 1);
265                                                
266                                                if (alterSQL != null)
267                                                {
268                                                        logger.debug(alterSQL);
269                                                        
270                                                        targetStatement.addBatch(alterSQL);
271                                                }
272                                        }
273                                        
274                                        targetStatement.executeBatch();
275                                }
276                        }
277                }
278                
279                sourceStatement.close();
280                targetStatement.close();
281        }
282
283        /**
284         * @param <D>
285         * @param context
286         * @throws SQLException
287         */
288        public static <D> void dropUniqueConstraints(SynchronizationContext<D> context) throws SQLException
289        {
290                Dialect dialect = context.getDialect();
291
292                Connection connection = context.getConnection(context.getTargetDatabase());
293                
294                Statement statement = connection.createStatement();
295                
296                for (TableProperties table: context.getTargetDatabaseProperties().getTables())
297                {
298                        for (UniqueConstraint constraint: table.getUniqueConstraints())
299                        {
300                                String sql = dialect.getDropUniqueConstraintSQL(constraint);
301                                
302                                logger.debug(sql);
303                                
304                                statement.addBatch(sql);
305                        }
306                }
307                
308                statement.executeBatch();
309                statement.close();
310        }
311        
312        /**
313         * @param <D>
314         * @param context
315         * @throws SQLException
316         */
317        public static <D> void restoreUniqueConstraints(SynchronizationContext<D> context) throws SQLException
318        {
319                Dialect dialect = context.getDialect();
320
321                Connection connection = context.getConnection(context.getTargetDatabase());
322                
323                Statement statement = connection.createStatement();
324                
325                for (TableProperties table: context.getSourceDatabaseProperties().getTables())
326                {
327                        // Drop unique constraints on the current table
328                        for (UniqueConstraint constraint: table.getUniqueConstraints())
329                        {
330                                String sql = dialect.getCreateUniqueConstraintSQL(constraint);
331                                
332                                logger.debug(sql);
333                                
334                                statement.addBatch(sql);
335                        }
336                }
337                
338                statement.executeBatch();
339                statement.close();
340        }
341        
342        /**
343         * @param connection
344         */
345        public static void rollback(Connection connection)
346        {
347                try
348                {
349                        connection.rollback();
350                        connection.setAutoCommit(true);
351                }
352                catch (SQLException e)
353                {
354                        logger.warn(e.toString(), e);
355                }
356        }
357        
358        /**
359         * Helper method for {@link java.sql.ResultSet#getObject(int)} with special handling for large objects.
360         * @param resultSet
361         * @param index
362         * @param type
363         * @return the object of the specified type at the specified index from the specified result set
364         * @throws SQLException
365         */
366        public static Object getObject(ResultSet resultSet, int index, int type) throws SQLException
367        {
368                switch (type)
369                {
370                        case Types.BLOB:
371                        {
372                                return resultSet.getBlob(index);
373                        }
374                        case Types.CLOB:
375                        {
376                                return resultSet.getClob(index);
377                        }
378                        default:
379                        {
380                                return resultSet.getObject(index);
381                        }
382                }
383        }
384}