FFI and ODBC connectivity

Krasimir Angelov ka2_mail@yahoo.com
Thu, 6 Jun 2002 06:57:04 -0700 (PDT)


--0-1914213734-1023371824=:11434
Content-Type: text/plain; charset=us-ascii
Content-Disposition: inline

Hello all,

 HaskellDB is a nice combinator library, but it has
two main disadvantages:

    * HaskellDB uses the 'Trex' module, which is Hugs-specific;
neither GHC nor NHC supports 'Trex'.
    * HaskellDB can't execute stored SQL procedures
and doesn't allow executing plain SQL statements.

 I use HSQL (see mail attachment) for data access.
HSQL works with ODBC, but its user interface isn't
ODBC-specific. The module can be rewritten to use native
drivers for specific databases
(MySQL, PostgreSQL, Oracle, Sybase, ...). If anyone
is interested in using HSQL, I will put it
in CVS.

Krasimir Angelov

__________________________________________________
Do You Yahoo!?
Yahoo! - Official partner of 2002 FIFA World Cup
http://fifaworldcup.yahoo.com
--0-1914213734-1023371824=:11434
Content-Type: text/plain; name="HSQL.hsc"
Content-Description: HSQL.hsc
Content-Disposition: inline; filename="HSQL.hsc"

-- | HSQL: database access through ODBC.
--
-- This is an hsc2hs source: constants, struct offsets and sizes are taken
-- from C headers (HSQLStructs.h and whatever it includes) at build time.
-- The comments in the export list give each function's type.
module HSQL
		( SqlBind(..), SqlError(..), SqlType(..), Connection, Statement
		, catchSql			-- :: IO a -> (SqlError -> IO a) -> IO a
		, connect			-- :: String -> String -> String -> IO Connection
		, disconnect		-- :: Connection -> IO ()
		, execute			-- :: Connection -> String -> IO Statement
		, closeStatement	-- :: Statement -> IO ()
		, fetch				-- :: Statement -> IO Bool
		, inTransaction		-- :: Connection -> (Connection -> IO a) -> IO a
		, getFieldValue		-- :: SqlBind a => Statement -> String -> IO a
		, getFieldValueType	-- :: Statement -> String -> (SqlType, Bool)
		, getFieldsTypes	-- :: Statement -> [(String, SqlType, Bool)]
		, forEachRow		-- :: (Statement -> s -> IO s) -> Statement -> s -> IO s
		, forEachRow'		-- :: (Statement -> IO ()) -> Statement -> IO ()
		, collectRows		-- :: (Statement -> IO s) -> Statement -> IO [s]
		)	where

import Word(Word32, Word16)
import Int(Int32, Int16)
import Foreign
import CString
import IORef
import Monad(when)
import Exception (throwDyn, catchDyn, Exception(..))
import Dynamic

#include <HSQLStructs.h>

-- ODBC handles are opaque C pointers. The environment handle is wrapped in a
-- ForeignPtr so that a finalizer can release it (see 'myEnvironment').
type SQLHANDLE = Ptr ()
type HENV      = SQLHANDLE
type HDBC      = SQLHANDLE
type HSTMT     = SQLHANDLE
type HENVRef   = ForeignPtr ()

-- Fixed-width Haskell counterparts of the ODBC C integer types.
type SQLSMALLINT  = Int16
type SQLUSMALLINT = Word16
type SQLINTEGER   = Int32
type SQLUINTEGER  = Word32
type SQLRETURN    = SQLSMALLINT

-- NOTE(review): the ODBC headers declare SQLULEN unsigned, and on 64-bit
-- builds both SQLLEN and SQLULEN are 64 bits wide. These aliases match
-- 32-bit headers only — confirm before porting.
type SQLLEN  = SQLINTEGER
type SQLULEN = SQLINTEGER

-- Raw ODBC entry points, imported with the stdcall calling convention.
-- The Haskell argument types follow the C prototypes except where noted.
foreign import stdcall "sqlext.h SQLAllocEnv" sqlAllocEnv :: Ptr HENV -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeEnv" sqlFreeEnv :: HENV -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLAllocConnect" sqlAllocConnect :: HENV -> Ptr HDBC -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeConnect" sqlFreeConnect:: HDBC -> IO SQLRETURN
-- NOTE(review): the C prototype declares the three length arguments as
-- SQLSMALLINT, but they are passed here as Haskell Int. This depends on how
-- the target ABI passes small integer arguments — confirm on any new platform.
foreign import stdcall "sqlext.h SQLConnect" sqlConnect :: HDBC -> CString -> Int -> CString -> Int -> CString -> Int -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLDisconnect" sqlDisconnect :: HDBC -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLAllocStmt" sqlAllocStmt :: HDBC -> Ptr HSTMT -> IO SQLRETURN
-- The second argument selects the free mode (SQL_CLOSE, SQL_DROP, ...).
foreign import stdcall "sqlext.h SQLFreeStmt" sqlFreeStmt :: HSTMT -> SQLUSMALLINT -> IO SQLRETURN
-- NOTE(review): the C prototype takes SQLSMALLINT*; an unsigned 16-bit
-- pointer of the same size is used here.
foreign import stdcall "sqlext.h SQLNumResultCols" sqlNumResultCols :: HSTMT -> Ptr SQLUSMALLINT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLDescribeCol" sqlDescribeCol :: HSTMT -> SQLUSMALLINT -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT -> Ptr SQLSMALLINT -> Ptr SQLULEN -> Ptr SQLSMALLINT -> Ptr SQLSMALLINT -> IO SQLRETURN
-- Last argument: the length/indicator word the driver fills in on each fetch.
foreign import stdcall "sqlext.h SQLBindCol" sqlBindCol :: HSTMT -> SQLUSMALLINT -> SQLSMALLINT -> Ptr a -> SQLLEN -> Ptr SQLINTEGER -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFetch" sqlFetch :: HSTMT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLGetDiagRec" sqlGetDiagRec :: SQLSMALLINT -> SQLHANDLE -> SQLSMALLINT -> CString -> Ptr SQLINTEGER -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT -> IO SQLRETURN
-- NOTE(review): TextLength is SQLINTEGER in the C prototype; passed as Int here.
foreign import stdcall "sqlext.h SQLExecDirect" sqlExecDirect :: HSTMT -> CString -> Int -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLSetConnectOption" sqlSetConnectOption :: HDBC -> SQLUSMALLINT -> SQLULEN -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLTransact" sqlTransact :: HENV -> HDBC -> SQLUSMALLINT -> IO SQLRETURN

-- | An open ODBC connection.
data Connection = Connection
  { hDBC        :: HDBC     -- ^ ODBC connection handle
  , environment :: HENVRef  -- ^ environment handle, used by 'inTransaction' for SQLTransact
  }

-- | One result column: its name, SQL type, nullability flag, and the byte
-- offset of the column's slot inside the statement's fetch buffer.
type FieldDef = (String, SqlType, Bool, Int)

-- | An executed statement together with its bound result columns.
data Statement = Statement
  { hSTMT       :: HSTMT       -- ^ ODBC statement handle
  , connection  :: Connection  -- ^ connection the statement was executed on
  , fields      :: [FieldDef]  -- ^ result columns, in column order
  , fetchBuffer :: Ptr ()      -- ^ buffer the result columns are bound into
  }

-- | SQL column types as reported by SQLDescribeCol. Constructor arguments
-- carry the reported column size; DECIMAL and NUMERIC additionally carry the
-- reported decimal digits.
data SqlType
  = SqlChar Int
  | SqlVarChar Int
  | SqlLongVarChar Int
  | SqlDecimal Int Int
  | SqlNumeric Int Int
  | SqlSmallInt
  | SqlInteger
  | SqlReal
  | SqlDouble
  | SqlBit
  | SqlTinyInt
  | SqlBigInt
  | SqlBinary Int
  | SqlVarBinary Int
  | SqlLongVarBinary Int
  | SqlDate
  | SqlTime
  | SqlTimeStamp
  deriving (Eq, Show)

-- | Errors raised (via 'throwDyn') by this module; catch them with 'catchSql'.
data SqlError
  -- SQL_ERROR: the SQLSTATE, native error code and message text taken from
  -- the first diagnostic record of the failing handle.
  = SqlError
      { seState       :: String
      , seNativeError :: Int
      , seErrorMsg    :: String
      }
  -- Not raised by the code in this module (SQL_NO_DATA counts as success).
  | SqlNoData
  -- Raised when an ODBC call returns SQL_INVALID_HANDLE.
  | SqlInvalidHandle
  -- Raised when an ODBC call returns SQL_STILL_EXECUTING.
  | SqlStillExecuting
  -- Raised when an ODBC call returns SQL_NEED_DATA.
  | SqlNeedData
  deriving Show
   
-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------

-- Type representation of 'SqlError', so that it can be thrown and caught as
-- a dynamic exception (throwDyn / catchDyn). NOINLINE keeps it a single
-- top-level value.
{-# NOINLINE sqlErrorTy #-}
sqlErrorTy = mkAppTy (mkTyCon "SqlError") []

instance Typeable SqlError where
  typeOf _ = sqlErrorTy

-- | Run an action and pass any 'SqlError' it raises to the handler.
-- Exceptions of other types are not intercepted.
catchSql :: IO a -> (SqlError -> IO a) -> IO a
catchSql action handler = catchDyn action handler

-- | Return codes treated as success. SQL_NO_DATA is included so that it is
-- not raised as an error; 'fetch' inspects it to detect the end of the rows.
sqlSuccess :: SQLRETURN -> Bool
sqlSuccess res =
  res `elem` [(#const SQL_SUCCESS), (#const SQL_SUCCESS_WITH_INFO), (#const SQL_NO_DATA)]

-- | Translate an ODBC return code into a Haskell outcome.
--
-- Codes accepted by 'sqlSuccess' return normally. SQL_INVALID_HANDLE,
-- SQL_STILL_EXECUTING and SQL_NEED_DATA are raised as the matching
-- 'SqlError' constructors. SQL_ERROR is raised as a 'SqlError' record built
-- from diagnostic record 1 of @handle@ (whose kind is given by
-- @handleType@). Any other code is reported with 'error'.
handleSqlResult :: SQLSMALLINT -> SQLHANDLE -> SQLRETURN -> IO ()
handleSqlResult handleType handle res
  | sqlSuccess res                      = return ()
  | res == (#const SQL_INVALID_HANDLE)  = throwDyn SqlInvalidHandle
  | res == (#const SQL_STILL_EXECUTING) = throwDyn SqlStillExecuting
  | res == (#const SQL_NEED_DATA)       = throwDyn SqlNeedData
  | res == (#const SQL_ERROR)           = readDiagnostic >>= throwDyn
  | otherwise                           = error (show res)
  where
    -- Read diagnostic record 1 into a SqlError. The buffers are scoped with
    -- alloca/allocaBytes, so they are released on every exit path. The
    -- outputs are pre-zeroed because SQLGetDiagRec's own result is not
    -- inspected. If it fails (e.g. when given a null handle), nothing is
    -- written, and reading uninitialised memory as a C string would produce
    -- garbage. With the zeroing, the fields come back empty or 0 instead.
    readDiagnostic :: IO SqlError
    readDiagnostic =
      allocaBytes 256 $ \pState ->
      alloca          $ \pNative ->
      allocaBytes 256 $ \pMsg ->
      alloca          $ \pTextLen -> do
        poke pState 0
        poke pMsg 0
        poke pNative 0
        sqlGetDiagRec handleType handle 1 pState pNative pMsg 256 pTextLen
        state  <- peekCString pState
        native <- peek pNative
        msg    <- peekCString pMsg
        return (SqlError {seState=state, seNativeError=fromIntegral native, seErrorMsg=msg})

-----------------------------------------------------------------------------------------
-- keeper of HENV
-----------------------------------------------------------------------------------------

-- | The process-wide ODBC environment handle, allocated with SQLAllocEnv
-- the first time it is demanded. Its finalizer releases it with SQLFreeEnv.
-- NOINLINE keeps the unsafePerformIO from being duplicated at use sites.
{-# NOINLINE myEnvironment #-}
myEnvironment :: HENVRef
myEnvironment = unsafePerformIO allocEnvironment
  where
    allocEnvironment = do
      (rc, hEnv) <- alloca $ \phEnv -> do
        r <- sqlAllocEnv phEnv
        h <- peek phEnv
        return (r, h)
      -- No environment handle exists to query for diagnostics here, so a
      -- null handle (with handle type 0) is passed.
      handleSqlResult 0 nullPtr rc
      newForeignPtr hEnv (freeEnvironment hEnv)

    freeEnvironment :: HENV -> IO ()
    freeEnvironment hEnv =
      sqlFreeEnv hEnv >>= handleSqlResult (#const SQL_HANDLE_ENV) hEnv

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Open a connection to an ODBC data source.
--
-- @server@, @user@ and @authentication@ are passed to SQLConnect as the data
-- source name, user name and password. The connection is allocated in the
-- shared environment 'myEnvironment'. Failures are raised as 'SqlError'. If
-- SQLConnect fails, the connection handle is freed before the error is
-- re-raised.
connect :: String -> String -> String -> IO Connection
connect server user authentication = withForeignPtr myEnvironment $ \hEnv -> do
  hDBC <- alloca $ \phDBC -> do
    res <- sqlAllocConnect hEnv phDBC
    handleSqlResult (#const SQL_HANDLE_ENV) hEnv res
    peek phDBC
  pServer         <- newCString server
  pUser           <- newCString user
  pAuthentication <- newCString authentication
  res <- sqlConnect hDBC pServer (length server) pUser (length user) pAuthentication (length authentication)
  free pServer
  free pUser
  free pAuthentication
  -- SQLConnect records its diagnostics on the connection handle, not on the
  -- environment, so read them from hDBC. Only then free the handle, so that a
  -- failed connect does not leak it.
  catchSql (handleSqlResult (#const SQL_HANDLE_DBC) hDBC res)
           (\err -> sqlFreeConnect hDBC >> throwDyn err)
  return (Connection {hDBC=hDBC, environment=myEnvironment})
	
-- | Disconnect and then free the connection's ODBC handle. Failures of
-- either step are raised as 'SqlError'.
disconnect :: Connection -> IO ()
disconnect conn = do
  let h     = hDBC conn
      check = handleSqlResult (#const SQL_HANDLE_DBC) h
  check =<< sqlDisconnect h
  check =<< sqlFreeConnect h
	
-----------------------------------------------------------------------------------------
-- queries
-----------------------------------------------------------------------------------------

-- | Execute an SQL statement and bind its result columns for fetching.
--
-- The query text is run with SQLExecDirect. Each result column is then:
--
--   * described with SQLDescribeCol,
--   * given a slot in a single fetch buffer, and
--   * bound to that slot with SQLBindCol.
--
-- A slot holds an SQLINTEGER length/indicator word followed by the column
-- data. This is the layout the 'SqlBind' instances read.
--
-- Failures of the ODBC calls are raised as 'SqlError'. If one occurs after
-- the statement handle has been allocated, the handle is freed with SQL_DROP
-- (and the fetch buffer is released, if it was already allocated) before the
-- error is re-raised. Other exceptions are not intercepted, for example the
-- 'error' raised for an unsupported column type.
execute :: Connection -> String -> IO Statement
execute conn@(Connection {hDBC=hDBC}) query =
  -- Scratch FIELD struct used as out-parameter storage for the ODBC calls.
  -- allocaBytes releases it on both normal and exceptional exit.
  allocaBytes (#const sizeof(FIELD)) $ \pFIELD -> do
    res <- sqlAllocStmt hDBC ((#ptr FIELD, hSTMT) pFIELD)
    handleSqlResult (#const SQL_HANDLE_DBC) hDBC res
    hSTMT <- (#peek FIELD, hSTMT) pFIELD
    (fields, bufferSize) <- catchSql (describe hSTMT pFIELD) (dropStatement hSTMT)
    buffer <- mallocBytes bufferSize
    catchSql (bindFields hSTMT buffer 1 fields)
             (\err -> free buffer >> dropStatement hSTMT err)
    return (Statement {hSTMT=hSTMT, connection=conn, fields=fields, fetchBuffer=buffer})
  where
    -- Free the statement handle, then re-raise err. The SQLFreeStmt result
    -- is ignored because we are already failing with the more relevant error.
    dropStatement :: HSTMT -> SqlError -> IO a
    dropStatement hSTMT err = do
      sqlFreeStmt hSTMT (#const SQL_DROP)
      throwDyn err

    -- Run the query, then describe its result columns. Returns the column
    -- descriptions and the fetch-buffer size they need.
    describe :: HSTMT -> Ptr a -> IO ([FieldDef], Int)
    describe hSTMT pFIELD = do
      pQuery <- newCString query
      res <- sqlExecDirect hSTMT pQuery (length query)
      free pQuery
      handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
      sqlNumResultCols hSTMT ((#ptr FIELD, fieldsCount) pFIELD)
        >>= handleSqlResult (#const SQL_HANDLE_STMT) hSTMT
      count <- (#peek FIELD, fieldsCount) pFIELD
      createBindState hSTMT pFIELD 0 1 count

    -- Describe columns n..count. offs is the first free byte of the fetch
    -- buffer. Returns the field list and the total buffer size needed.
    createBindState :: HSTMT -> Ptr a -> Int -> SQLUSMALLINT -> SQLUSMALLINT -> IO ([FieldDef], Int)
    createBindState hSTMT pFIELD offs n count
      | n > count = return ([], offs)
      | otherwise = do
          res <- sqlDescribeCol hSTMT n
                   ((#ptr FIELD, fieldName) pFIELD) (#const FIELD_NAME_LENGTH)
                   ((#ptr FIELD, NameLength) pFIELD) ((#ptr FIELD, DataType) pFIELD)
                   ((#ptr FIELD, ColumnSize) pFIELD) ((#ptr FIELD, DecimalDigits) pFIELD)
                   ((#ptr FIELD, Nullable) pFIELD)
          handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
          name          <- peekCString ((#ptr FIELD, fieldName) pFIELD)
          dataType      <- (#peek FIELD, DataType) pFIELD
          columnSize    <- (#peek FIELD, ColumnSize) pFIELD
          decimalDigits <- (#peek FIELD, DecimalDigits) pFIELD
          (nullable :: SQLSMALLINT) <- (#peek FIELD, Nullable) pFIELD
          let slot = alignSlot offs
              (sqlType, offs') = mkSqlType dataType columnSize decimalDigits (slot + (#const sizeof(SQLINTEGER)))
          (rest, offs'') <- createBindState hSTMT pFIELD offs' (n+1) count
          return ((name, sqlType, toBool nullable, slot) : rest, offs'')

    -- Smallest offset >= o at which a slot can start so that its data, which
    -- follows the indicator word, lands on an 8-byte boundary. 8 bytes is the
    -- strictest alignment of the C types bound here (SQLDOUBLE). This assumes
    -- the block from mallocBytes is itself at least 8-byte aligned.
    alignSlot :: Int -> Int
    alignSlot o = o + ((indicatorSize - o) `mod` 8)
      where indicatorSize = (#const sizeof(SQLINTEGER))

    -- Bind each described column to its slot. Columns are numbered from n.
    bindFields :: HSTMT -> Ptr () -> SQLUSMALLINT -> [FieldDef] -> IO ()
    bindFields _ _ _ [] = return ()
    bindFields hSTMT fetchBuffer n ((_, sqlType, _, offs) : rest) = do
      let slot = fetchBuffer `plusPtr` offs
          (cType, bufferLen) = cBinding sqlType
      res <- sqlBindCol hSTMT n cType (slot `plusPtr` (#const sizeof(SQLINTEGER))) bufferLen (castPtr slot)
      handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
      bindFields hSTMT fetchBuffer (n+1) rest

    -- C target type and buffer length used to bind each SqlType. These must
    -- agree with the slot sizes reserved by mkSqlType.
    -- NOTE(review): SqlBigInt is bound as a 32-bit SQL_C_LONG, so values
    -- outside the Int32 range will presumably fail to convert.
    cBinding :: SqlType -> (SQLSMALLINT, SQLLEN)
    cBinding (SqlChar size)          = charBinding (size + 1)
    cBinding (SqlVarChar size)       = charBinding (size + 1)
    cBinding (SqlLongVarChar size)   = charBinding (size + 1)
    cBinding (SqlDecimal _ _)        = ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
    cBinding (SqlNumeric _ _)        = ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
    cBinding SqlSmallInt             = ((#const SQL_C_SHORT), (#const sizeof(SQLSMALLINT)))
    cBinding SqlInteger              = ((#const SQL_C_LONG), (#const sizeof(SQLINTEGER)))
    cBinding SqlReal                 = ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
    cBinding SqlDouble               = ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
    cBinding SqlBit                  = ((#const SQL_C_LONG), (#const sizeof(SQLINTEGER)))
    cBinding SqlTinyInt              = ((#const SQL_C_SHORT), (#const sizeof(SQLSMALLINT)))
    cBinding SqlBigInt               = ((#const SQL_C_LONG), (#const sizeof(SQLINTEGER)))
    cBinding (SqlBinary size)        = charBinding size
    cBinding (SqlVarBinary size)     = charBinding size
    cBinding (SqlLongVarBinary size) = charBinding size
    cBinding SqlDate                 = ((#const SQL_C_TYPE_DATE), (#const sizeof(SQL_DATE_STRUCT)))
    cBinding SqlTime                 = ((#const SQL_C_TYPE_TIME), (#const sizeof(SQL_TIME_STRUCT)))
    cBinding SqlTimeStamp            = ((#const SQL_C_TYPE_TIMESTAMP), (#const sizeof(SQL_TIMESTAMP_STRUCT)))

    -- Character binding whose buffer holds len SQLCHARs.
    charBinding :: Int -> (SQLSMALLINT, SQLLEN)
    charBinding len = ((#const SQL_C_CHAR), fromIntegral len * (#const sizeof(SQLCHAR)))

    -- Bytes reserved for a character or binary column: one SQLCHAR per
    -- unit of column size, plus one for a terminating NUL.
    charBytes :: SQLULEN -> Int
    charBytes size = (#const sizeof(SQLCHAR)) * (fromIntegral size + 1)

    -- Map the SQL data type reported by SQLDescribeCol to a SqlType, and
    -- advance offs past the data part of the column's slot.
    mkSqlType :: SQLSMALLINT -> SQLULEN -> SQLSMALLINT -> Int -> (SqlType, Int)
    mkSqlType (#const SQL_CHAR)           size _    offs = (SqlChar (fromIntegral size), offs + charBytes size)
    mkSqlType (#const SQL_VARCHAR)        size _    offs = (SqlVarChar (fromIntegral size), offs + charBytes size)
    mkSqlType (#const SQL_LONGVARCHAR)    size _    offs = (SqlLongVarChar (fromIntegral size), offs + charBytes size)
    mkSqlType (#const SQL_DECIMAL)        size prec offs = (SqlDecimal (fromIntegral size) (fromIntegral prec), offs + (#const sizeof(SQLDOUBLE)))
    mkSqlType (#const SQL_NUMERIC)        size prec offs = (SqlNumeric (fromIntegral size) (fromIntegral prec), offs + (#const sizeof(SQLDOUBLE)))
    mkSqlType (#const SQL_SMALLINT)       _    _    offs = (SqlSmallInt, offs + (#const sizeof(SQLSMALLINT)))
    mkSqlType (#const SQL_INTEGER)        _    _    offs = (SqlInteger, offs + (#const sizeof(SQLINTEGER)))
    mkSqlType (#const SQL_REAL)           _    _    offs = (SqlReal, offs + (#const sizeof(SQLDOUBLE)))
    mkSqlType (#const SQL_FLOAT)          _    _    offs = (SqlDouble, offs + (#const sizeof(SQLDOUBLE)))
    mkSqlType (#const SQL_DOUBLE)         _    _    offs = (SqlDouble, offs + (#const sizeof(SQLDOUBLE)))
    mkSqlType (#const SQL_BIT)            _    _    offs = (SqlBit, offs + (#const sizeof(SQLINTEGER)))
    mkSqlType (#const SQL_TINYINT)        _    _    offs = (SqlTinyInt, offs + (#const sizeof(SQLSMALLINT)))
    mkSqlType (#const SQL_BIGINT)         _    _    offs = (SqlBigInt, offs + (#const sizeof(SQLINTEGER)))
    mkSqlType (#const SQL_BINARY)         size _    offs = (SqlBinary (fromIntegral size), offs + charBytes size)
    mkSqlType (#const SQL_VARBINARY)      size _    offs = (SqlVarBinary (fromIntegral size), offs + charBytes size)
    mkSqlType (#const SQL_LONGVARBINARY)  size _    offs = (SqlLongVarBinary (fromIntegral size), offs + charBytes size)
    mkSqlType (#const SQL_DATE)           _    _    offs = (SqlDate, offs + (#const sizeof(SQL_DATE_STRUCT)))
    mkSqlType (#const SQL_TIME)           _    _    offs = (SqlTime, offs + (#const sizeof(SQL_TIME_STRUCT)))
    mkSqlType (#const SQL_TIMESTAMP)      _    _    offs = (SqlTimeStamp, offs + (#const sizeof(SQL_TIMESTAMP_STRUCT)))
    -- ODBC 3 drivers may report the date/time types with these codes instead.
    mkSqlType (#const SQL_TYPE_DATE)      _    _    offs = (SqlDate, offs + (#const sizeof(SQL_DATE_STRUCT)))
    mkSqlType (#const SQL_TYPE_TIME)      _    _    offs = (SqlTime, offs + (#const sizeof(SQL_TIME_STRUCT)))
    mkSqlType (#const SQL_TYPE_TIMESTAMP) _    _    offs = (SqlTimeStamp, offs + (#const sizeof(SQL_TIMESTAMP_STRUCT)))
    mkSqlType dataType                    _    _    _    =
      error ("HSQL.execute: unsupported SQL data type " ++ show dataType)


-- | Fetch the next row into the statement's bound buffer. Returns False once
-- SQLFetch reports SQL_NO_DATA. Other failures are raised as 'SqlError'.
{-# NOINLINE fetch #-}
fetch :: Statement -> IO Bool
fetch stmt = do
  let h = hSTMT stmt
  rc <- sqlFetch h
  handleSqlResult (#const SQL_HANDLE_STMT) h rc
  return (rc /= (#const SQL_NO_DATA))
	

-- | Release a statement: free its ODBC statement handle, then its fetch buffer.
--
-- The handle is freed with SQL_DROP. SQL_CLOSE (option 0) would only close
-- the cursor and leave the handle allocated, leaking one handle per statement
-- for as long as the connection stays open. Failures are raised as 'SqlError'.
closeStatement :: Statement -> IO ()
closeStatement stmt = do
  sqlFreeStmt (hSTMT stmt) (#const SQL_DROP) >>= handleSqlResult (#const SQL_HANDLE_STMT) (hSTMT stmt)
  free (fetchBuffer stmt)

-----------------------------------------------------------------------------------------
-- transactions
-----------------------------------------------------------------------------------------

-- | Run an action as a single transaction on the given connection.
--
-- Autocommit is switched off, the action runs, and the transaction is
-- committed. Autocommit is then switched back on. If switching autocommit
-- off fails, the error is raised before the action runs, so the action
-- never executes outside a transaction. If the action or the COMMIT raises
-- a 'SqlError', the transaction is rolled back, autocommit is restored and
-- the error is re-raised. Exceptions other than 'SqlError' are not
-- intercepted.
inTransaction :: Connection -> (Connection -> IO a) -> IO a
inTransaction conn@(Connection {hDBC=hDBC, environment=envRef}) action =
  withForeignPtr envRef $ \hEnv -> do
    let setAutoCommit mode =
          sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) mode
        checkDBC = handleSqlResult (#const SQL_HANDLE_DBC) hDBC
        -- Best-effort cleanup: the rollback and autocommit results are not
        -- checked, because err is already the error being reported.
        abort :: SqlError -> IO b
        abort err = do
          sqlTransact hEnv hDBC (#const SQL_ROLLBACK)
          setAutoCommit (#const SQL_AUTOCOMMIT_ON)
          throwDyn err
    setAutoCommit (#const SQL_AUTOCOMMIT_OFF) >>= checkDBC
    r <- catchSql (action conn) abort
    -- Check the COMMIT result before any further call on hDBC, which would
    -- replace its diagnostic records.
    commitRes <- sqlTransact hEnv hDBC (#const SQL_COMMIT)
    catchSql (checkDBC commitRes) abort
    setAutoCommit (#const SQL_AUTOCOMMIT_ON) >>= checkDBC
    return r
-----------------------------------------------------------------------------------------
-- binding
-----------------------------------------------------------------------------------------

-- | Types that can be read from a bound result column.
--
-- 'getValue' receives the column's 'SqlType' and a pointer to the column's
-- slot in the fetch buffer. A slot starts with the SQLINTEGER length/indicator
-- word filled in by the driver, followed by the column data.
class SqlBind a where
  getValue :: SqlType -> Ptr () -> IO a

-- Length/indicator word at the start of a column slot.
peekIndicator :: Ptr () -> IO SQLINTEGER
peekIndicator slot = peek (castPtr slot)

-- Start of the column data, just after the indicator word.
valuePtr :: Ptr () -> Ptr b
valuePtr slot = slot `plusPtr` (#const sizeof(SQLINTEGER))

-- Read data bound as SQL_C_LONG. Peeking the C width (SQLINTEGER) rather
-- than a Haskell Int avoids over-reading where Int is wider than 32 bits.
peekLong :: Ptr () -> IO Int
peekLong slot = fmap fromIntegral (peek (valuePtr slot) :: IO SQLINTEGER)

-- Read data bound as SQL_C_SHORT.
peekShort :: Ptr () -> IO Int
peekShort slot = fmap fromIntegral (peek (valuePtr slot) :: IO SQLSMALLINT)

-- Read character data from a column of the given size. The indicator holds
-- the full length of the value. That length exceeds what was stored when
-- the driver truncated the value, and it is negative for NULL. Clamping it
-- to [0, size] keeps the read inside the bound buffer.
peekChars :: Int -> Ptr () -> IO String
peekChars size slot = do
  len <- peekIndicator slot
  peekCStringLen (valuePtr slot, max 0 (min size (fromIntegral len)))

instance SqlBind Int where
  getValue SqlInteger  slot = peekLong slot
  getValue SqlBit      slot = peekLong slot
  getValue SqlBigInt   slot = peekLong slot
  getValue SqlSmallInt slot = peekShort slot
  getValue SqlTinyInt  slot = peekShort slot
  getValue sqlType     _    = error ("HSQL.getValue: cannot read " ++ show sqlType ++ " as Int")

instance SqlBind String where
  getValue (SqlChar size)        slot = peekChars size slot
  getValue (SqlVarChar size)     slot = peekChars size slot
  getValue (SqlLongVarChar size) slot = peekChars size slot
  getValue sqlType               _    = error ("HSQL.getValue: cannot read " ++ show sqlType ++ " as String")

instance SqlBind Double where
  getValue (SqlDecimal _ _) slot = peek (valuePtr slot)
  getValue (SqlNumeric _ _) slot = peek (valuePtr slot)
  getValue SqlDouble        slot = peek (valuePtr slot)
  getValue SqlReal          slot = peek (valuePtr slot)
  getValue sqlType          _    = error ("HSQL.getValue: cannot read " ++ show sqlType ++ " as Double")

-- | NULL-aware reads. Returns 'Nothing' when the indicator word is
-- SQL_NULL_DATA, and otherwise the value read through the underlying instance.
instance SqlBind a => SqlBind (Maybe a) where
  getValue sqlType slot = do
    ind <- peekIndicator slot
    if ind == (#const SQL_NULL_DATA)
      then return Nothing
      else fmap Just (getValue sqlType slot)
	

-- | Read the named column of the current row as any 'SqlBind' type.
-- An unknown column name is reported by 'findField'.
getFieldValue :: SqlBind a => Statement -> String -> IO a
getFieldValue stmt name = getValue colType slot
  where
    (_, colType, _, colOffset) = findField name (fields stmt)
    slot = fetchBuffer stmt `plusPtr` colOffset

-- | The SQL type and nullability flag of the named column.
getFieldValueType :: Statement -> String -> (SqlType, Bool)
getFieldValueType stmt name = (colType, isNullable)
  where
    (_, colType, isNullable, _) = findField name (fields stmt)
		
-- | Name, SQL type and nullability flag of every result column, in order.
getFieldsTypes :: Statement -> [(String, SqlType, Bool)]
getFieldsTypes stmt = map dropOffset (fields stmt)
  where
    dropOffset (name, sqlType, nullable, _) = (name, sqlType, nullable)

-- | Look up a column description by name. The first match wins. If no
-- column has that name, 'error' is called with a message naming the column.
findField :: String -> [FieldDef] -> FieldDef
findField name [] = error ("HSQL: no field named " ++ show name ++ " in the result set")
findField name (fieldDef@(name', _, _, _) : rest)
  | name == name' = fieldDef
  | otherwise     = findField name rest

-----------------------------------------------------------------------------------------
-- helpers
-----------------------------------------------------------------------------------------

-- | Fold f over the remaining rows of stmt, threading the state through.
-- The statement is closed after the last row has been fetched. If f raises
-- an exception, the statement is left open for the caller to close.
forEachRow :: (Statement -> s -> IO s) -> Statement -> s -> IO s
forEachRow f stmt = go
  where
    go acc = do
      more <- fetch stmt
      if not more
        then closeStatement stmt >> return acc
        else f stmt acc >>= go
	
-- | Run f on each remaining row of stmt, then close the statement. If f
-- raises an exception, the statement is left open for the caller to close.
forEachRow' :: (Statement -> IO ()) -> Statement -> IO ()
forEachRow' f stmt = do
  more <- fetch stmt
  case more of
    True  -> f stmt >> forEachRow' f stmt
    False -> closeStatement stmt
	
-- | Apply f to each remaining row of stmt and collect the results in row
-- order. The statement is closed once the rows are exhausted. If f raises an
-- exception, the statement is left open for the caller to close.
collectRows :: (Statement -> IO a) -> Statement -> IO [a]
collectRows f stmt = do
  more <- fetch stmt
  if more
    then do
      row  <- f stmt
      rows <- collectRows f stmt
      return (row : rows)
    else do
      closeStatement stmt
      return []
--0-1914213734-1023371824=:11434--