FFI and ODBC connectivity
Krasimir Angelov
ka2_mail@yahoo.com
Thu, 6 Jun 2002 06:57:04 -0700 (PDT)
--0-1914213734-1023371824=:11434
Content-Type: text/plain; charset=us-ascii
Content-Disposition: inline
Hello all,
HaskellDB is a nice combinator library, but it has
two main disadvantages:
* HaskellDB uses the 'Trex' module, which is
Hugs-specific. Neither GHC nor NHC supports 'Trex'.
* HaskellDB can't execute stored SQL procedures
and doesn't allow executing plain SQL statements.
I use HSQL (see mail attachment) for data access.
HSQL works with ODBC, but its user interface isn't
ODBC-specific. The module can be rewritten to use native
drivers for specific databases
(MySQL, PostgreSQL, Oracle, Sybase, ...). If anyone is
interested in using HSQL, I will put it
on CVS.
Krasimir Angelov
__________________________________________________
Do You Yahoo!?
Yahoo! - Official partner of 2002 FIFA World Cup
http://fifaworldcup.yahoo.com
--0-1914213734-1023371824=:11434
Content-Type: text/plain; name="HSQL.hsc"
Content-Description: HSQL.hsc
Content-Disposition: inline; filename="HSQL.hsc"
-- | HSQL: a small database-access layer on top of ODBC.
--
-- Failures reported by the underlying ODBC calls are thrown as dynamic
-- 'SqlError' exceptions; handle them with 'catchSql'.
module HSQL
( SqlBind(..), SqlError(..), SqlType(..), Connection, Statement
, catchSql -- :: IO a -> (SqlError -> IO a) -> IO a
, connect -- :: String -> String -> String -> IO Connection
, disconnect -- :: Connection -> IO ()
, execute -- :: Connection -> String -> IO Statement
, closeStatement -- :: Statement -> IO ()
, fetch -- :: Statement -> IO Bool
, inTransaction -- :: Connection -> (Connection -> IO a) -> IO a
, getFieldValue -- :: SqlBind a => Statement -> String -> IO a
, getFieldValueType -- :: Statement -> String -> (SqlType, Bool)
, getFieldsTypes -- :: Statement -> [(String, SqlType, Bool)]
, forEachRow -- :: (Statement -> s -> IO s) -> Statement -> s -> IO s
, forEachRow' -- :: (Statement -> IO ()) -> Statement -> IO ()
, collectRows -- :: (Statement -> IO s) -> Statement -> IO [s]
) where
import Word(Word32, Word16)
import Int(Int32, Int16)
import Foreign
import CString
import IORef
import Monad(when)
import Exception (throwDyn, catchDyn, Exception(..))
import Dynamic
#include <HSQLStructs.h>
-- Haskell counterparts of the ODBC handle and integer typedefs.
-- Every ODBC handle is treated as an untyped pointer.
type SQLHANDLE = Ptr ()
type HENV = SQLHANDLE
type HDBC = SQLHANDLE
type HSTMT = SQLHANDLE
-- The shared environment handle, wrapped in a ForeignPtr whose finalizer
-- frees it (see myEnvironment).
type HENVRef = ForeignPtr ()
type SQLSMALLINT = Int16
type SQLUSMALLINT = Word16
type SQLINTEGER = Int32
type SQLUINTEGER = Word32
type SQLRETURN = SQLSMALLINT
-- NOTE(review): SQLLEN/SQLULEN are aliased to the 32-bit signed SQLINTEGER
-- here. 64-bit ODBC headers typically make them pointer-width, and SQLULEN
-- unsigned. Confirm these aliases match the target platform.
type SQLLEN = SQLINTEGER
type SQLULEN = SQLINTEGER
-- Raw ODBC entry points. All of them return an SQLRETURN status code, which
-- callers pass to handleSqlResult.
-- NOTE(review): the `stdcall` convention matches the Windows ODBC32 DLL;
-- other platforms presumably need `ccall` — confirm.
foreign import stdcall "sqlext.h SQLAllocEnv" sqlAllocEnv :: Ptr HENV -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeEnv" sqlFreeEnv :: HENV -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLAllocConnect" sqlAllocConnect :: HENV -> Ptr HDBC -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFreeConnect" sqlFreeConnect:: HDBC -> IO SQLRETURN
-- NOTE(review): the string-length arguments are declared as Int here, while
-- the ODBC prototype uses SQLSMALLINT. This relies on argument widening in
-- the calling convention — confirm on the target platform.
foreign import stdcall "sqlext.h SQLConnect" sqlConnect :: HDBC -> CString -> Int -> CString -> Int -> CString -> Int -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLDisconnect" sqlDisconnect :: HDBC -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLAllocStmt" sqlAllocStmt :: HDBC -> Ptr HSTMT -> IO SQLRETURN
-- Second argument is the free option (e.g. SQL_CLOSE / SQL_DROP).
foreign import stdcall "sqlext.h SQLFreeStmt" sqlFreeStmt :: HSTMT -> SQLUSMALLINT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLNumResultCols" sqlNumResultCols :: HSTMT -> Ptr SQLUSMALLINT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLDescribeCol" sqlDescribeCol :: HSTMT -> SQLUSMALLINT -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT -> Ptr SQLSMALLINT -> Ptr SQLULEN -> Ptr SQLSMALLINT -> Ptr SQLSMALLINT -> IO SQLRETURN
-- Binds a column to a target buffer (Ptr a) plus a SQLINTEGER length/indicator slot.
foreign import stdcall "sqlext.h SQLBindCol" sqlBindCol :: HSTMT -> SQLUSMALLINT -> SQLSMALLINT -> Ptr a -> SQLLEN -> Ptr SQLINTEGER -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLFetch" sqlFetch :: HSTMT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLGetDiagRec" sqlGetDiagRec :: SQLSMALLINT -> SQLHANDLE -> SQLSMALLINT -> CString -> Ptr SQLINTEGER -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLExecDirect" sqlExecDirect :: HSTMT -> CString -> Int -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLSetConnectOption" sqlSetConnectOption :: HDBC -> SQLUSMALLINT -> SQLULEN -> IO SQLRETURN
foreign import stdcall "sqlext.h SQLTransact" sqlTransact :: HENV -> HDBC -> SQLUSMALLINT -> IO SQLRETURN
-- | An open ODBC connection created by 'connect'.
--
-- Holding the environment 'ForeignPtr' keeps the shared environment
-- reachable (its finalizer frees it) for as long as the connection is.
data Connection = Connection
    { hDBC        :: HDBC     -- ^ ODBC connection handle
    , environment :: HENVRef  -- ^ environment the connection was allocated from
    }
-- | Description of one result column:
-- (column name, column type, nullable flag, byte offset of the column's
-- slot inside the statement's fetch buffer).
type FieldDef = (String, SqlType, Bool, Int)

-- | An executed statement whose result rows can be read with 'fetch'.
data Statement = Statement
    { hSTMT       :: HSTMT       -- ^ ODBC statement handle
    , connection  :: Connection  -- ^ connection the statement was run on
    , fields      :: [FieldDef]  -- ^ result columns, in column order
    , fetchBuffer :: Ptr ()      -- ^ buffer that all columns are bound into
    }
-- | SQL column types as reported by the driver (see 'mkSqlType').
-- The 'Int' arguments carry the column size; 'SqlDecimal' and 'SqlNumeric'
-- also carry the number of decimal digits.
data SqlType
    = SqlChar Int              -- ^ fixed-length character data
    | SqlVarChar Int           -- ^ variable-length character data
    | SqlLongVarChar Int       -- ^ long variable-length character data
    | SqlDecimal Int Int       -- ^ size, decimal digits
    | SqlNumeric Int Int       -- ^ size, decimal digits
    | SqlSmallInt
    | SqlInteger
    | SqlReal
    | SqlDouble
    | SqlBit
    | SqlTinyInt
    | SqlBigInt
    | SqlBinary Int            -- ^ fixed-length binary data
    | SqlVarBinary Int         -- ^ variable-length binary data
    | SqlLongVarBinary Int     -- ^ long variable-length binary data
    | SqlDate
    | SqlTime
    | SqlTimeStamp
    deriving (Eq, Show)
-- | Errors thrown (as dynamic exceptions) by this module; catch them with
-- 'catchSql'. 'handleSqlResult' maps ODBC return codes onto these.
data SqlError
    -- | SQL_ERROR, described by the first diagnostic record of the handle.
    = SqlError
        { seState       :: String  -- ^ SQLSTATE code
        , seNativeError :: Int     -- ^ driver-specific error code
        , seErrorMsg    :: String  -- ^ diagnostic message text
        }
    | SqlNoData          -- ^ not thrown by handleSqlResult (SQL_NO_DATA counts as success there)
    | SqlInvalidHandle   -- ^ SQL_INVALID_HANDLE
    | SqlStillExecuting  -- ^ SQL_STILL_EXECUTING
    | SqlNeedData        -- ^ SQL_NEED_DATA
    deriving Show
-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------
-- | Type representation that tags 'SqlError' values thrown with 'throwDyn',
-- so that 'catchDyn' (and therefore 'catchSql') can recognise them.
{-# NOINLINE sqlErrorTy #-}
sqlErrorTy = mkTyCon "SqlError" `mkAppTy` []

instance Typeable SqlError where
    typeOf _ = sqlErrorTy
-- | Run an action, handing any 'SqlError' it throws to the handler.
-- Other exceptions propagate unchanged.
catchSql :: IO a -> (SqlError -> IO a) -> IO a
catchSql action handler = action `catchDyn` handler
-- | True for the return codes treated as non-failures: SQL_SUCCESS,
-- SQL_SUCCESS_WITH_INFO and SQL_NO_DATA. 'fetch' relies on SQL_NO_DATA
-- being accepted here to detect the end of the result set.
sqlSuccess :: SQLRETURN -> Bool
sqlSuccess res =
    res `elem` [(#const SQL_SUCCESS), (#const SQL_SUCCESS_WITH_INFO), (#const SQL_NO_DATA)]
-- | Turn an ODBC return code into success or a thrown 'SqlError'
-- (see 'catchSql').
--
-- @handleType@ / @handle@ identify the handle the failing call was made on.
-- For SQL_ERROR, the handle's first diagnostic record is read with
-- SQLGetDiagRec and reported as a 'SqlError'. Unrecognised codes abort with
-- 'error'.
handleSqlResult :: SQLSMALLINT -> SQLHANDLE -> SQLRETURN -> IO ()
handleSqlResult handleType handle res
    | sqlSuccess res = return ()
    | res == (#const SQL_INVALID_HANDLE) = throwDyn SqlInvalidHandle
    | res == (#const SQL_STILL_EXECUTING) = throwDyn SqlStillExecuting
    | res == (#const SQL_NEED_DATA) = throwDyn SqlNeedData
    | res == (#const SQL_ERROR) = readDiagnostic >>= throwDyn
    | otherwise = error (show res)
  where
    -- The scratch buffers are scoped with alloca so they are released even
    -- if marshalling throws. They are cleared first because SQLGetDiagRec
    -- may leave them unwritten when it fails itself (for example when
    -- called with a null handle). Reading uninitialised memory with
    -- peekCString could run past the buffer looking for a NUL.
    readDiagnostic =
        allocaBytes 256 $ \pState ->
        alloca $ \pNative ->
        allocaBytes 256 $ \pMsg ->
        alloca $ \pTextLen -> do
            poke pState 0
            poke pNative 0
            poke pMsg 0
            _ <- sqlGetDiagRec handleType handle 1 pState pNative pMsg 256 pTextLen
            state <- peekCString pState
            native <- peek pNative
            msg <- peekCString pMsg
            return (SqlError {seState=state, seNativeError=fromIntegral native, seErrorMsg=msg})
-----------------------------------------------------------------------------------------
-- keeper of HENV
-----------------------------------------------------------------------------------------
-- | The process-wide ODBC environment handle, created on first use.
-- 'unsafePerformIO' together with the NOINLINE pragma makes this a single
-- shared value. The ForeignPtr finalizer releases the handle with
-- SQLFreeEnv.
{-# NOINLINE myEnvironment #-}
myEnvironment :: HENVRef
myEnvironment = unsafePerformIO allocEnvironment
  where
    allocEnvironment = do
        (res, hEnv) <- alloca $ \phEnv -> do
            rc <- sqlAllocEnv phEnv
            h <- peek phEnv
            return (rc, h)
        handleSqlResult 0 nullPtr res
        newForeignPtr hEnv (releaseEnvironment hEnv)

    releaseEnvironment :: HENV -> IO ()
    releaseEnvironment hEnv =
        sqlFreeEnv hEnv >>= handleSqlResult (#const SQL_HANDLE_ENV) hEnv
-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------
-- | Open a connection to an ODBC data source.
--
-- The arguments are the data source name, the user name and the password,
-- passed to SQLConnect. Throws 'SqlError' on failure; if SQLConnect fails,
-- the freshly allocated connection handle is released before the error
-- propagates.
connect :: String -> String -> String -> IO Connection
connect server user authentication = withForeignPtr myEnvironment $ \hEnv -> do
    (res, hDBC) <- alloca $ \phDBC -> do
        rc <- sqlAllocConnect hEnv phDBC
        h <- peek phDBC
        return (rc, h)
    handleSqlResult (#const SQL_HANDLE_ENV) hEnv res
    -- withCStringLen supplies the marshalled byte length (not the Char
    -- count) and frees each buffer when the call returns.
    res' <- withCStringLen server $ \(pServer, lServer) ->
            withCStringLen user $ \(pUser, lUser) ->
            withCStringLen authentication $ \(pAuth, lAuth) ->
                sqlConnect hDBC pServer lServer pUser lUser pAuth lAuth
    -- SQLConnect's diagnostics are attached to the connection handle,
    -- not the environment. Free the handle on failure so it does not leak.
    catchSql (handleSqlResult (#const SQL_HANDLE_DBC) hDBC res') $ \err -> do
        _ <- sqlFreeConnect hDBC
        throwDyn err
    return (Connection {hDBC=hDBC, environment=myEnvironment})
-- | Disconnect from the data source and free the connection handle.
-- Throws 'SqlError' if either ODBC call fails.
disconnect :: Connection -> IO ()
disconnect conn = do
    let handle = hDBC conn
        check = handleSqlResult (#const SQL_HANDLE_DBC) handle
    check =<< sqlDisconnect handle
    check =<< sqlFreeConnect handle
-----------------------------------------------------------------------------------------
-- queries
-----------------------------------------------------------------------------------------
-- | Execute an SQL statement and prepare its result set for 'fetch'.
--
-- Every result column is described and then bound into a single fetch
-- buffer. Throws 'SqlError' on failure. The statement handle is released if
-- execution or column description fails, and 'closeStatement' runs if
-- binding fails. The returned 'Statement' must eventually be passed to
-- 'closeStatement'. The 'forEachRow' / 'forEachRow'' / 'collectRows'
-- helpers do this once the rows are exhausted.
execute :: Connection -> String -> IO Statement
execute conn@(Connection {hDBC=hDBC}) query = allocaBytes (#const sizeof(FIELD)) $ \pFIELD -> do
    -- pFIELD is scratch space for the ODBC out-parameters; it is only needed
    -- while the statement is being set up, so alloca scopes it.
    res <- sqlAllocStmt hDBC ((#ptr FIELD, hSTMT) pFIELD)
    handleSqlResult (#const SQL_HANDLE_DBC) hDBC res
    hSTMT <- (#peek FIELD, hSTMT) pFIELD
    let handleResult = handleSqlResult (#const SQL_HANDLE_STMT) hSTMT
    (fields, offs) <- dropOnError hSTMT $ do
        withCStringLen query $ \(pQuery, len) ->
            sqlExecDirect hSTMT pQuery len >>= handleResult
        sqlNumResultCols hSTMT ((#ptr FIELD, fieldsCount) pFIELD) >>= handleResult
        count <- (#peek FIELD, fieldsCount) pFIELD
        createBindState hSTMT pFIELD 0 1 count
    buffer <- mallocBytes offs
    let statement = Statement {hSTMT=hSTMT, connection=conn, fields=fields, fetchBuffer=buffer}
    catchSql (bindFields hSTMT buffer 1 fields) (errHandler statement)
    return statement
  where
    errHandler statement err = do
        closeStatement statement
        throwDyn err
    -- Free the statement handle (SQL_DROP) if setup fails before a
    -- Statement exists, then re-throw the original error.
    dropOnError hstmt action =
        catchSql action (\err -> sqlFreeStmt hstmt (#const SQL_DROP) >> throwDyn err)
-- | Describe result columns @n .. count@ with SQLDescribeCol and lay out
-- their slots in the fetch buffer, starting at byte offset @offs@.
--
-- Each slot is a SQLINTEGER length/indicator followed by the data area
-- sized by 'mkSqlType'. Returns the column descriptions and the total
-- buffer size. @pFIELD@ is scratch space for the describe out-parameters.
createBindState :: HSTMT -> Ptr a -> Int -> SQLUSMALLINT -> SQLUSMALLINT -> IO ([FieldDef], Int)
createBindState hSTMT pFIELD offs n count =
    if n > count
        then return ([], offs)
        else do
            let pName     = (#ptr FIELD, fieldName) pFIELD
                pNameLen  = (#ptr FIELD, NameLength) pFIELD
                pType     = (#ptr FIELD, DataType) pFIELD
                pSize     = (#ptr FIELD, ColumnSize) pFIELD
                pDigits   = (#ptr FIELD, DecimalDigits) pFIELD
                pNullable = (#ptr FIELD, Nullable) pFIELD
            sqlDescribeCol hSTMT n pName (#const FIELD_NAME_LENGTH) pNameLen pType pSize pDigits pNullable
                >>= handleSqlResult (#const SQL_HANDLE_STMT) hSTMT
            name          <- peekCString pName
            dataType      <- peek pType
            columnSize    <- peek pSize
            decimalDigits <- peek pDigits
            nullable      <- peek pNullable :: IO SQLSMALLINT
            let (sqlType, next) = mkSqlType dataType columnSize decimalDigits (offs + (#const sizeof(SQLINTEGER)))
            (rest, end) <- createBindState hSTMT pFIELD next (n + 1) count
            return ((name, sqlType, toBool nullable, offs) : rest, end)
-- | Bind each described column (numbered from @n@) to its slot in the
-- fetch buffer with SQLBindCol.
--
-- The slot at the column's offset starts with the SQLINTEGER
-- length/indicator, followed by the data area that receives the value.
-- Throws 'SqlError' if a bind fails.
bindFields :: HSTMT -> Ptr () -> SQLUSMALLINT -> [FieldDef] -> IO ()
bindFields _ _ _ [] = return ()
bindFields hSTMT fetchBuffer n ((_, sqlType, _, offs) : rest) = do
    let slot = fetchBuffer `plusPtr` offs :: Ptr ()
        (cType, bufLen) = sqlTypeBinding sqlType
    res <- sqlBindCol hSTMT n cType (slot `plusPtr` (#const sizeof(SQLINTEGER))) bufLen (castPtr slot)
    handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
    bindFields hSTMT fetchBuffer (n + 1) rest

-- | C target type and data-area length used to bind a column of the given
-- 'SqlType'. The lengths match the space 'mkSqlType' reserves.
--
-- Binary columns are bound as SQL_C_BINARY: as SQL_C_CHAR the driver
-- would hex-encode them (two characters per byte) and truncate the data
-- in a buffer of @size@ bytes.
-- NOTE(review): SqlBigInt is bound as 32-bit SQL_C_LONG, so values outside
-- the Int32 range do not round-trip. Widening it needs matching changes in
-- mkSqlType and the SqlBind instances.
sqlTypeBinding :: SqlType -> (SQLSMALLINT, SQLLEN)
sqlTypeBinding sqlType =
    case sqlType of
        SqlChar size          -> ((#const SQL_C_CHAR), textBytes size)
        SqlVarChar size       -> ((#const SQL_C_CHAR), textBytes size)
        SqlLongVarChar size   -> ((#const SQL_C_CHAR), textBytes size)
        SqlDecimal _ _        -> ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
        SqlNumeric _ _        -> ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
        SqlSmallInt           -> ((#const SQL_C_SHORT), (#const sizeof(SQLSMALLINT)))
        SqlInteger            -> ((#const SQL_C_LONG), (#const sizeof(SQLINTEGER)))
        SqlReal               -> ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
        SqlDouble             -> ((#const SQL_C_DOUBLE), (#const sizeof(SQLDOUBLE)))
        SqlBit                -> ((#const SQL_C_LONG), (#const sizeof(SQLINTEGER)))
        SqlTinyInt            -> ((#const SQL_C_SHORT), (#const sizeof(SQLSMALLINT)))
        SqlBigInt             -> ((#const SQL_C_LONG), (#const sizeof(SQLINTEGER)))
        SqlBinary size        -> ((#const SQL_C_BINARY), binaryBytes size)
        SqlVarBinary size     -> ((#const SQL_C_BINARY), binaryBytes size)
        SqlLongVarBinary size -> ((#const SQL_C_BINARY), binaryBytes size)
        SqlDate               -> ((#const SQL_C_TYPE_DATE), (#const sizeof(SQL_DATE_STRUCT)))
        SqlTime               -> ((#const SQL_C_TYPE_TIME), (#const sizeof(SQL_TIME_STRUCT)))
        SqlTimeStamp          -> ((#const SQL_C_TYPE_TIMESTAMP), (#const sizeof(SQL_TIMESTAMP_STRUCT)))
  where
    -- Character data is NUL-terminated, so reserve one extra character.
    textBytes size = fromIntegral (size + 1) * (#const sizeof(SQLCHAR))
    binaryBytes size = fromIntegral size * (#const sizeof(SQLCHAR))
-- | Map an ODBC SQL data type code (from SQLDescribeCol) to a 'SqlType'.
-- Also advance the fetch-buffer offset by the size of the column's data
-- area.
--
-- Arguments: data type code, column size, decimal digits, and the current
-- offset (already past the column's SQLINTEGER length indicator).
-- Character and binary columns reserve @size + 1@ bytes.
-- Numeric types use a SQLDOUBLE, SQLSMALLINT or SQLINTEGER, matching how
-- they are bound.
-- Both ODBC 2 and ODBC 3 date/time codes are accepted. Any other code is
-- reported with a descriptive 'error' instead of a pattern-match failure.
-- NOTE(review): for LONGVARCHAR/LONGVARBINARY, the driver-reported size can
-- be very large, and offsets are not aligned for the doubles that follow —
-- confirm acceptable on the target platform.
mkSqlType :: SQLSMALLINT -> SQLULEN -> SQLSMALLINT -> Int -> (SqlType, Int)
mkSqlType (#const SQL_CHAR) size _ offs = (SqlChar (fromIntegral size), offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_VARCHAR) size _ offs = (SqlVarChar (fromIntegral size), offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_LONGVARCHAR) size _ offs = (SqlLongVarChar (fromIntegral size), offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_DECIMAL) size prec offs = (SqlDecimal (fromIntegral size) (fromIntegral prec), offs + (#const sizeof(SQLDOUBLE)))
mkSqlType (#const SQL_NUMERIC) size prec offs = (SqlNumeric (fromIntegral size) (fromIntegral prec), offs + (#const sizeof(SQLDOUBLE)))
mkSqlType (#const SQL_SMALLINT) _ _ offs = (SqlSmallInt, offs + (#const sizeof(SQLSMALLINT)))
mkSqlType (#const SQL_INTEGER) _ _ offs = (SqlInteger, offs + (#const sizeof(SQLINTEGER)))
mkSqlType (#const SQL_REAL) _ _ offs = (SqlReal, offs + (#const sizeof(SQLDOUBLE)))
mkSqlType (#const SQL_DOUBLE) _ _ offs = (SqlDouble, offs + (#const sizeof(SQLDOUBLE)))
mkSqlType (#const SQL_BIT) _ _ offs = (SqlBit, offs + (#const sizeof(SQLINTEGER)))
mkSqlType (#const SQL_TINYINT) _ _ offs = (SqlTinyInt, offs + (#const sizeof(SQLSMALLINT)))
mkSqlType (#const SQL_BIGINT) _ _ offs = (SqlBigInt, offs + (#const sizeof(SQLINTEGER)))
mkSqlType (#const SQL_BINARY) size _ offs = (SqlBinary (fromIntegral size), offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_VARBINARY) size _ offs = (SqlVarBinary (fromIntegral size), offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_LONGVARBINARY) size _ offs = (SqlLongVarBinary (fromIntegral size), offs + (#const sizeof(SQLCHAR))*(fromIntegral size+1))
mkSqlType (#const SQL_DATE) _ _ offs = (SqlDate, offs + (#const sizeof(SQL_DATE_STRUCT)))
mkSqlType (#const SQL_TIME) _ _ offs = (SqlTime, offs + (#const sizeof(SQL_TIME_STRUCT)))
mkSqlType (#const SQL_TIMESTAMP) _ _ offs = (SqlTimeStamp, offs + (#const sizeof(SQL_TIMESTAMP_STRUCT)))
-- ODBC 3.x codes for the same date/time types.
mkSqlType (#const SQL_TYPE_DATE) _ _ offs = (SqlDate, offs + (#const sizeof(SQL_DATE_STRUCT)))
mkSqlType (#const SQL_TYPE_TIME) _ _ offs = (SqlTime, offs + (#const sizeof(SQL_TIME_STRUCT)))
mkSqlType (#const SQL_TYPE_TIMESTAMP) _ _ offs = (SqlTimeStamp, offs + (#const sizeof(SQL_TIMESTAMP_STRUCT)))
mkSqlType dataType _ _ _ = error ("HSQL.mkSqlType: unsupported SQL data type code " ++ show dataType)
-- | Move the statement's cursor to the next row, filling the fetch buffer.
-- Returns False once SQL_NO_DATA signals the end of the result set.
-- Throws 'SqlError' on failure.
{-# NOINLINE fetch #-}
fetch :: Statement -> IO Bool
fetch stmt = do
    let handle = hSTMT stmt
    rc <- sqlFetch handle
    handleSqlResult (#const SQL_HANDLE_STMT) handle rc
    return $ rc /= (#const SQL_NO_DATA)
-- | Release a statement returned by 'execute': free its ODBC statement
-- handle and its fetch buffer. The statement must not be used afterwards.
-- Throws 'SqlError' if SQLFreeStmt fails; in that case the fetch buffer
-- is left allocated, since the handle may still reference it.
closeStatement :: Statement -> IO ()
closeStatement stmt = do
    -- SQL_DROP frees the handle itself. SQL_CLOSE (option 0) would only
    -- close the cursor and leak the handle.
    sqlFreeStmt (hSTMT stmt) (#const SQL_DROP) >>= handleSqlResult (#const SQL_HANDLE_STMT) (hSTMT stmt)
    free (fetchBuffer stmt)
-----------------------------------------------------------------------------------------
-- transactions
-----------------------------------------------------------------------------------------
-- | Run @action@ inside a transaction on the connection.
--
-- Autocommit is switched off, the action runs, and the transaction is then
-- committed. Autocommit is switched back on afterwards. If the action or
-- the commit throws 'SqlError', the transaction is rolled back, autocommit
-- is restored (best effort), and the error is re-thrown.
-- NOTE(review): only 'SqlError' triggers the rollback; other exceptions
-- propagate with autocommit still off — consider a general handler.
inTransaction :: Connection -> (Connection -> IO a) -> IO a
inTransaction conn@(Connection {hDBC=hDBC, environment=envRef}) action = withForeignPtr envRef $ \hEnv -> do
    let checkDBC = handleSqlResult (#const SQL_HANDLE_DBC) hDBC
        setAutoCommit mode =
            sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) mode >>= checkDBC
        -- Cleanup after a failure. Return codes are ignored so that the
        -- original error is the one that propagates.
        rollback = do
            _ <- sqlTransact hEnv hDBC (#const SQL_ROLLBACK)
            _ <- sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) (#const SQL_AUTOCOMMIT_ON)
            return ()
    setAutoCommit (#const SQL_AUTOCOMMIT_OFF)
    r <- catchSql (action conn) (\err -> rollback >> throwDyn err)
    -- A failed commit must not be silent: report it and roll back.
    catchSql (sqlTransact hEnv hDBC (#const SQL_COMMIT) >>= checkDBC)
             (\err -> rollback >> throwDyn err)
    setAutoCommit (#const SQL_AUTOCOMMIT_ON)
    return r
-----------------------------------------------------------------------------------------
-- binding
-----------------------------------------------------------------------------------------
-- | Types that can be decoded from a bound result column.
class SqlBind a where
    -- | Decode a column of the given 'SqlType' from its slot in the fetch
    -- buffer. The slot starts with the SQLINTEGER length/indicator,
    -- followed by the data area.
    getValue :: SqlType -> Ptr () -> IO a

-- Integer columns are bound as SQL_C_LONG (32 bits) or SQL_C_SHORT
-- (16 bits) by bindFields. Read exactly that width, then widen, because
-- Int may be wider than 32 bits.
instance SqlBind Int where
    getValue SqlInteger ptr = peekInt32Field ptr
    getValue SqlBit ptr = peekInt32Field ptr
    getValue SqlBigInt ptr = peekInt32Field ptr
    getValue SqlSmallInt ptr = peekInt16Field ptr
    getValue SqlTinyInt ptr = peekInt16Field ptr
    getValue t _ = error ("HSQL.getValue: cannot read a column of type " ++ show t ++ " as Int")

instance SqlBind String where
    getValue (SqlChar size) ptr = peekCharField size ptr
    getValue (SqlVarChar size) ptr = peekCharField size ptr
    getValue (SqlLongVarChar size) ptr = peekCharField size ptr
    getValue t _ = error ("HSQL.getValue: cannot read a column of type " ++ show t ++ " as String")

-- These column types are bound as SQL_C_DOUBLE by bindFields.
instance SqlBind Double where
    getValue (SqlDecimal _ _) ptr = peek (fieldData ptr)
    getValue (SqlNumeric _ _) ptr = peek (fieldData ptr)
    getValue SqlDouble ptr = peek (fieldData ptr)
    getValue SqlReal ptr = peek (fieldData ptr)
    getValue t _ = error ("HSQL.getValue: cannot read a column of type " ++ show t ++ " as Double")

-- Pointer to the data area of a field slot (just past its length indicator).
fieldData :: Ptr () -> Ptr b
fieldData ptr = ptr `plusPtr` (#const sizeof(SQLINTEGER))

-- Read a field bound as SQL_C_LONG and widen it to Int.
peekInt32Field :: Ptr () -> IO Int
peekInt32Field ptr = do
    n <- peek (fieldData ptr) :: IO Int32
    return (fromIntegral n)

-- Read a field bound as SQL_C_SHORT and widen it to Int.
peekInt16Field :: Ptr () -> IO Int
peekInt16Field ptr = do
    n <- peek (fieldData ptr) :: IO Int16
    return (fromIntegral n)

-- Read a character field whose column size is @size@.
-- The indicator is a SQLINTEGER, so it is read at that width. It reports
-- the full value length even when the value was truncated to fit the
-- buffer, and it is negative for NULL. The length is therefore clamped
-- to [0, size] so the read never leaves the buffer.
peekCharField :: Int -> Ptr () -> IO String
peekCharField size ptr = do
    len <- peek (castPtr ptr) :: IO SQLINTEGER
    let n = max 0 (min size (fromIntegral len))
    peekCStringLen (fieldData ptr, n)
-- | Read the named column's value from the current row. The result type
-- selects the 'SqlBind' instance; an unknown column name is an error
-- (see 'findField').
getFieldValue :: SqlBind a => Statement -> String -> IO a
getFieldValue stmt name =
    let (_, sqlType, _, offs) = findField name (fields stmt)
    in getValue sqlType (fetchBuffer stmt `plusPtr` offs)
-- | Type and nullable flag of the named result column.
getFieldValueType :: Statement -> String -> (SqlType, Bool)
getFieldValueType stmt name =
    let (_, sqlType, nullable, _) = findField name (fields stmt)
    in (sqlType, nullable)
-- | Name, type and nullable flag of every result column, in column order.
getFieldsTypes :: Statement -> [(String, SqlType, Bool)]
getFieldsTypes stmt =
    [ (name, sqlType, nullable) | (name, sqlType, nullable, _) <- fields stmt ]
-- | First column definition with the given name; calls 'error' when the
-- name does not occur.
findField :: String -> [FieldDef] -> FieldDef
findField name = foldr pick (error (name ++ "??"))
  where
    pick fd@(name', _, _, _) rest = if name == name' then fd else rest
-----------------------------------------------------------------------------------------
-- helpers
-----------------------------------------------------------------------------------------
-- | Fold @f@ over the remaining rows, threading the state @s@.
-- The statement is closed when the rows run out.
forEachRow :: (Statement -> s -> IO s) -> Statement -> s -> IO s
forEachRow f stmt = step
  where
    step s = do
        more <- fetch stmt
        if more
            then f stmt s >>= step
            else closeStatement stmt >> return s
-- | Run @f@ for each remaining row. The statement is closed when the rows
-- run out.
forEachRow' :: (Statement -> IO ()) -> Statement -> IO ()
forEachRow' f stmt = loop
  where
    loop = fetch stmt >>= \more ->
        if more then f stmt >> loop else closeStatement stmt
-- | Apply @f@ to each remaining row and collect the results in row order.
-- The statement is closed when the rows run out.
collectRows :: (Statement -> IO a) -> Statement -> IO [a]
collectRows f stmt = do
    more <- fetch stmt
    if not more
        then closeStatement stmt >> return []
        else do
            row <- f stmt
            rows <- collectRows f stmt
            return (row : rows)
--0-1914213734-1023371824=:11434--