package com.rapid.j2ee.framework.orm.mybatis.pagination.intercept;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.parameter.DefaultParameterHandler;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.datasource.DataSourceUtils;
import org.springframework.util.Assert;

import com.itextpdf.text.log.SysoLogger;
import com.rapid.j2ee.framework.core.exception.ExceptionUtils;
import com.rapid.j2ee.framework.core.pagination.PageConstants;
import com.rapid.j2ee.framework.core.pagination.PageOutputContainer;
import com.rapid.j2ee.framework.core.reflect.InvokeUtils;
import com.rapid.j2ee.framework.core.spring.SpringApplicationContextHolder;
import com.rapid.j2ee.framework.core.utils.CollectionsUtil;
import com.rapid.j2ee.framework.core.utils.StringUtils;
import com.rapid.j2ee.framework.core.utils.TypeChecker;
import com.rapid.j2ee.framework.format.Formatter;
import com.rapid.j2ee.framework.format.SQLFormatter;

import org.apache.ibatis.mapping.MappedStatement.Builder;

/**
 * MyBatis {@link Interceptor} that transparently paginates {@code SELECT}
 * statements executed through {@code Executor.query(...)}.
 * <p>
 * A query is paginated only when its SQL starts with the {@code SELECT}
 * keyword and its parameter object is a {@link Map} containing both
 * {@code PageConstants.Start_Row_Name} and {@code PageConstants.End_Row_Name}.
 * The two values are treated as 1-based, inclusive row numbers. For such a
 * query the interceptor:
 * <ol>
 * <li>runs a count query produced by a {@link SqlCountConverter} and stores the
 * result in the parameter map under
 * {@code PageOutputContainer.OUTPUT_PARAMETER_KEY + "_TotalCounts"};</li>
 * <li>wraps the SQL in a dialect-specific row-limiting clause, selected by the
 * {@code dialect} configuration variable ({@code mysql} or {@code oracle});</li>
 * <li>replaces the {@link MappedStatement} argument with a copy whose
 * {@link SqlSource} returns the rewritten {@link BoundSql}.</li>
 * </ol>
 */
@Intercepts(@Signature(type = Executor.class, method = "query", args = {
		MappedStatement.class, Object.class, RowBounds.class,
		ResultHandler.class }))
public class PaginationInterceptor implements Interceptor {

	/** Length of the SQL keyword {@code select}. */
	private static final String SELECT_KEYWORD = "select";

	/**
	 * Intercepts {@code Executor.query}. If the statement qualifies for
	 * pagination, counts the total rows, rewrites the SQL into a page query
	 * and swaps in a matching {@link MappedStatement}. Otherwise it proceeds
	 * unchanged.
	 *
	 * @param invocation
	 *            the intercepted call; args are (MappedStatement, parameter,
	 *            RowBounds, ResultHandler)
	 * @return the result of the (possibly rewritten) query
	 * @throws Throwable
	 *             whatever the underlying executor throws, or
	 *             {@link IllegalArgumentException} for an unsupported dialect
	 *             or non-numeric row parameters
	 */
	public Object intercept(Invocation invocation) throws Throwable {
		if (Logger.isDebugEnabled()) {
			Logger
					.debug("PaginationInterceptor  intercept(Invocation invocation) called........");
		}

		// Populate the @Autowired converters lazily if they have not been
		// injected yet (presumably because MyBatis, not Spring, instantiated
		// this plugin).
		if (TypeChecker.isEmpty(sqlCountConverters)) {
			SpringApplicationContextHolder.inject(this);
		}

		Object[] args = invocation.getArgs();
		MappedStatement mappedStatement = (MappedStatement) args[0];
		BoundSql boundSql = mappedStatement.getBoundSql(args[1]);

		if (!isAutoPaginationRequired(boundSql)) {
			return invocation.proceed();
		}

		// Safe: isAutoPaginationRequired() verified the parameter is a Map.
		Map parameters = (Map) boundSql.getParameterObject();

		// Build the page SQL from the ORIGINAL sql before setTotalRecord()
		// overwrites boundSql's sql with the count statement.
		String paginationSql = this.getPaginationSql(mappedStatement
				.getConfiguration(), parameters, boundSql.getSql());

		this.setTotalRecord((Executor) invocation.getTarget(), boundSql,
				mappedStatement);

		InvokeUtils.setField(boundSql, "sql", paginationSql);

		if (Logger.isInfoEnabled()) {
			Logger
					.info("\n-------------------Pagination Begin SQL ----------------------------------------------------------------------");
			Logger.info("");
			Logger.info("SQL:" + SQL_Formatter.format(paginationSql));
			Logger.info("");
			Logger.info("SQL Parameter:\n\t" + boundSql.getParameterObject());
			Logger.info("");
			Logger
					.info("-------------------Pagination End SQL----------------------------------------------------------------------\n");
		}

		// Replace the statement so the executor runs the rewritten BoundSql.
		args[0] = newMappedStatement(mappedStatement, new BoundSqlSource(
				boundSql));

		return invocation.proceed();
	}

	/**
	 * {@link SqlSource} that always returns one pre-built {@link BoundSql},
	 * ignoring the parameter object it is given.
	 */
	private static class BoundSqlSource implements SqlSource {

		private BoundSql boundSql;

		public BoundSqlSource(BoundSql boundSql) {
			this.boundSql = boundSql;
		}

		public BoundSql getBoundSql(Object parameterObject) {
			return boundSql;
		}
	}

	/**
	 * Decides whether a statement should be paginated automatically.
	 *
	 * @param boundSql
	 *            the statement's bound SQL
	 * @return {@code true} when the SQL starts with the {@code SELECT} keyword
	 *         (case-insensitive, followed by any whitespace including a
	 *         newline) and the parameter object is a {@link Map} containing
	 *         both the start-row and end-row keys
	 */
	static boolean isAutoPaginationRequired(BoundSql boundSql) {

		String sql = boundSql.getSql();

		if (TypeChecker.isEmpty(sql)) {
			return false;
		}

		if (!startsWithSelectKeyword(sql.trim())) {
			return false;
		}

		Object parameters = boundSql.getParameterObject();

		if (!(parameters instanceof Map)) {
			return false;
		}

		Map parameter = (Map) parameters;

		return parameter.containsKey(PageConstants.Start_Row_Name)
				&& parameter.containsKey(PageConstants.End_Row_Name);
	}

	/**
	 * Returns true when {@code sql} begins with {@code select}
	 * (case-insensitive) immediately followed by a whitespace character.
	 * Any whitespace is accepted, so SQL from XML mappers (e.g.
	 * {@code "SELECT\n  a, b ..."}) is recognised.
	 */
	private static boolean startsWithSelectKeyword(String sql) {
		int keywordLength = SELECT_KEYWORD.length();
		return sql.length() > keywordLength
				&& sql.regionMatches(true, 0, SELECT_KEYWORD, 0, keywordLength)
				&& Character.isWhitespace(sql.charAt(keywordLength));
	}

	/**
	 * Executes the count query for the current page request and stores the
	 * total in the parameter map under
	 * {@code PageOutputContainer.OUTPUT_PARAMETER_KEY + "_TotalCounts"}.
	 * <p>
	 * The count runs on the connection of the intercepted executor's
	 * transaction. It therefore sees the same transactional state as the page
	 * query and does not take a second pooled connection. That connection is
	 * owned by the executor and is deliberately NOT closed here.
	 * <p>
	 * Best effort: if the count query fails, the failure is logged and the
	 * total is simply left unset.
	 * <p>
	 * Side effect: {@code boundSql}'s {@code sql} field is overwritten with
	 * the count SQL; the caller is expected to replace it afterwards.
	 *
	 * @param executor
	 *            the executor handling the intercepted query
	 * @param boundSql
	 *            bound SQL of the original query (parameter object is a Map)
	 * @param mappedStatement
	 *            mapped statement of the original query
	 */
	private void setTotalRecord(Executor executor, BoundSql boundSql,
			MappedStatement mappedStatement) {

		Map parameters = (Map) boundSql.getParameterObject();

		String countSql = this.getCountSql(mappedStatement.getId(), boundSql
				.getSql(), parameters);

		if (Logger.isInfoEnabled()) {
			Logger
					.info("\n-------------------Mybaties Pagination Count Begin SQL ----------------------------------------------------------------------");
			Logger.info("");
			// Log the count statement that is actually executed below.
			Logger.info("SQL:" + SQL_Formatter.format(countSql));
			Logger.info("");
			Logger.info("SQL Parameter:\n\t" + boundSql.getParameterObject());
			Logger.info("");
			Logger
					.info("-------------------Mybaties Pagination Count End SQL ----------------------------------------------------------------------\n");
		}

		// DefaultParameterHandler binds parameters using boundSql's parameter
		// mappings, so the count SQL must keep the original placeholders.
		InvokeUtils.setField(boundSql, "sql", countSql);

		ParameterHandler parameterHandler = new DefaultParameterHandler(
				mappedStatement, parameters, boundSql);

		PreparedStatement pstmt = null;
		ResultSet rs = null;

		try {
			Connection con = executor.getTransaction().getConnection();

			pstmt = con.prepareStatement(countSql);
			parameterHandler.setParameters(pstmt);
			rs = pstmt.executeQuery();

			if (rs.next()) {
				int totalRecord = rs.getInt(1);

				parameters.put(PageOutputContainer.OUTPUT_PARAMETER_KEY
						+ "_TotalCounts", totalRecord);

				Logger
						.info("PageOutputContainer.OUTPUT_PARAMETER_KEY_TotalCounts ========= "
								+ totalRecord);
			}

		} catch (Exception e) {
			Logger.warn("Pagination count query failed for mapper ["
					+ mappedStatement.getId() + "]; total count not set", e);
		} finally {
			closeQuietly(rs);
			closeQuietly(pstmt);
		}
	}

	/** Closes a result set, logging (not propagating) any close failure. */
	private static void closeQuietly(ResultSet rs) {
		if (rs == null) {
			return;
		}
		try {
			rs.close();
		} catch (SQLException e) {
			Logger.debug("Failed to close pagination count ResultSet", e);
		}
	}

	/** Closes a statement, logging (not propagating) any close failure. */
	private static void closeQuietly(PreparedStatement statement) {
		if (statement == null) {
			return;
		}
		try {
			statement.close();
		} catch (SQLException e) {
			Logger.debug("Failed to close pagination count statement", e);
		}
	}

	/**
	 * Builds the count SQL for the original statement.
	 * <p>
	 * The converter whose {@code mapperId} equals the statement id is
	 * preferred; otherwise the converter registered as {@code "Default"} is
	 * used.
	 *
	 * @param id
	 *            the mapped statement id
	 * @param sql
	 *            the original SQL
	 * @param parameter
	 *            the statement's parameter map
	 * @return the count SQL produced by the selected converter
	 * @throws IllegalStateException
	 *             if neither a matching nor a "Default" converter exists
	 */
	private String getCountSql(String id, String sql, Map parameter) {

		SqlCountConverter sqlCountConverter = CollectionsUtil.findOne(
				this.sqlCountConverters, "mapperId", id);

		if (TypeChecker.isNull(sqlCountConverter)) {
			sqlCountConverter = CollectionsUtil.findOne(
					this.sqlCountConverters, "mapperId", "Default");
		}

		if (TypeChecker.isNull(sqlCountConverter)) {
			throw new IllegalStateException(
					"No SqlCountConverter registered for mapper [" + id
							+ "] and no \"Default\" converter available");
		}

		return sqlCountConverter.convert(sql, parameter);
	}

	/** Wraps {@code target} in a MyBatis plugin proxy for this interceptor. */
	public Object plugin(Object target) {

		return Plugin.wrap(target, this);
	}

	/** No configurable properties; intentionally a no-op. */
	public void setProperties(Properties properties) {

	}

	/**
	 * Builds the page SQL for the configured dialect. Only MySQL and Oracle
	 * are supported; any other dialect is rejected.
	 *
	 * @param configuration
	 *            MyBatis configuration; its {@code dialect} variable selects
	 *            the database type
	 * @param parameter
	 *            parameter map holding the start/end row values
	 * @param sql
	 *            the original SQL
	 * @return the row-limited SQL
	 * @throws IllegalArgumentException
	 *             if no dialect is configured, the dialect is unsupported, or
	 *             the row values are not numbers
	 */
	private String getPaginationSql(Configuration configuration, Map parameter,
			String sql) {

		StringBuilder sqlBuilder = new StringBuilder(sql);

		Properties variables = configuration.getVariables();
		String databaseType = variables == null ? null : variables
				.getProperty("dialect");

		Assert.hasLength(databaseType,
				"Please provide dialect in mybatis configure!");

		Logger.info("DatabaseType =====" + databaseType);

		if ("mysql".equalsIgnoreCase(databaseType)) {
			return getMysqlPageSql(parameter, sqlBuilder);
		} else if ("oracle".equalsIgnoreCase(databaseType)) {
			return getOraclePageSql(parameter, sqlBuilder);
		}

		throw new IllegalArgumentException(
				"Sorry, no provider to handle page sql! " + databaseType);
	}

	/**
	 * Reads a numeric row-bound parameter.
	 *
	 * @param parameter
	 *            the statement's parameter map
	 * @param name
	 *            the key to read
	 * @return the value as a long; any {@link Number} subtype is accepted
	 * @throws IllegalArgumentException
	 *             if the value is missing or not a number
	 */
	private static long getRowParameter(Map parameter, Object name) {
		Object value = parameter.get(name);
		if (!(value instanceof Number)) {
			throw new IllegalArgumentException("Pagination parameter [" + name
					+ "] must be a number but was: " + value);
		}
		return ((Number) value).longValue();
	}

	/**
	 * Appends a MySQL {@code LIMIT offset,count} clause.
	 *
	 * @param parameter
	 *            parameter map with 1-based inclusive start/end rows
	 * @param sqlBuilder
	 *            builder holding the original SQL; it is modified in place
	 * @return the MySQL page SQL
	 */
	private String getMysqlPageSql(Map parameter, StringBuilder sqlBuilder) {

		long startRow = getRowParameter(parameter,
				PageConstants.Start_Row_Name);
		long endRow = getRowParameter(parameter, PageConstants.End_Row_Name);

		// LIMIT takes a 0-based offset and a row count.
		sqlBuilder.append(" LIMIT ").append(startRow - 1).append(",").append(
				endRow - startRow + 1);
		return sqlBuilder.toString();
	}

	/**
	 * Wraps the SQL in nested ROWNUM sub-queries for Oracle. The result keeps
	 * an extra {@code rowIndex} column.
	 *
	 * @param parameter
	 *            parameter map with 1-based inclusive start/end rows
	 * @param sqlBuilder
	 *            builder holding the original SQL; it is modified in place
	 * @return the Oracle page SQL
	 */
	private String getOraclePageSql(Map parameter, StringBuilder sqlBuilder) {

		long startRow = getRowParameter(parameter,
				PageConstants.Start_Row_Name);
		long endRow = getRowParameter(parameter, PageConstants.End_Row_Name);

		// The inner query caps rows at endRow (ROWNUM can only be bounded
		// from above); the outer query then drops rows before startRow.
		sqlBuilder.insert(0, "select u.*, rownum rowIndex from (").append(
				") u where rownum <= ").append(endRow);
		sqlBuilder.insert(0, "select * from (").append(") where rowIndex >= ")
				.append(startRow);

		return sqlBuilder.toString();
	}

	/**
	 * Copies a {@link MappedStatement}, substituting a new {@link SqlSource}.
	 * <p>
	 * NOTE(review): attributes such as keyProperties, databaseId,
	 * resultOrdered and lang are not copied. This is presumably harmless for
	 * SELECT statements; verify against the MyBatis version in use.
	 *
	 * @param ms
	 *            the statement to copy
	 * @param boundSql
	 *            the SQL source of the copy
	 * @return the copied statement
	 */
	private MappedStatement newMappedStatement(MappedStatement ms,
			SqlSource boundSql) {

		Builder builder = new Builder(ms.getConfiguration(), ms.getId(),
				boundSql, ms.getSqlCommandType());

		builder.resource(ms.getResource());
		builder.fetchSize(ms.getFetchSize());
		builder.statementType(ms.getStatementType());
		builder.keyGenerator(ms.getKeyGenerator());
		builder.timeout(ms.getTimeout());
		builder.parameterMap(ms.getParameterMap());
		builder.resultMaps(ms.getResultMaps());
		builder.resultSetType(ms.getResultSetType());
		builder.cache(ms.getCache());
		builder.flushCacheRequired(ms.isFlushCacheRequired());
		builder.useCache(ms.isUseCache());

		return builder.build();
	}

	/** Count-SQL converters, matched by mapper id with a "Default" fallback. */
	@Autowired(required = false)
	private List<SqlCountConverter> sqlCountConverters = new ArrayList<SqlCountConverter>();

	private static final Log Logger = LogFactory
			.getLog(PaginationInterceptor.class);

	private static final Formatter SQL_Formatter = new SQLFormatter();

}
