Subversion

2lt

?curdirlinks? - Rev 1

?prevdifflink? - Blame


{-# OPTIONS -fglasgow-exts #-}

-----------------------------------------------------------------------------------------
{-| Module      :  Database.HSQL.ODBC
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  kr.angelov@gmail.com
    Stability   :  provisional
    Portability :  portable

    This module provides an interface to ODBC data sources.
-}
-----------------------------------------------------------------------------------------

module Database.HSQL.ODBC(connect, driverConnect, module Database.HSQL) where

import Database.HSQL
import Database.HSQL.Types
import Data.Word(Word32, Word16)
import Data.Int(Int32, Int16)
import Data.Maybe
import Foreign
import Foreign.C
import Control.Monad(unless)
import Control.Exception(throwDyn)
import Control.Concurrent.MVar
import System.IO.Unsafe
import System.Time
#ifdef DEBUG
import Debug.Trace
#endif

#include <time.h>
#include <HsODBC.h>

-- Opaque ODBC handles. Environment, connection and statement handles all
-- share one untyped pointer representation in this module.
type SQLHANDLE = Ptr ()
type HENV = SQLHANDLE
type HDBC = SQLHANDLE
type HSTMT = SQLHANDLE
-- Environment handle wrapped in a ForeignPtr so that a finalizer can be attached.
type HENVRef = ForeignPtr ()

-- Haskell counterparts of the ODBC integer typedefs; the concrete widths
-- are picked by hsc2hs from the C headers at build time.
type SQLSMALLINT  = #type SQLSMALLINT
type SQLUSMALLINT = #type SQLUSMALLINT
type SQLINTEGER   = #type SQLINTEGER
type SQLUINTEGER  = #type SQLUINTEGER
-- ODBC return codes are represented as small integers.
type SQLRETURN    = SQLSMALLINT
-- NOTE(review): both length types are aliased to the signed SQLINTEGER
-- rather than taken from the headers; presumably this matches older ODBC
-- headers only - confirm against the C definitions of SQLLEN / SQLULEN on
-- 64-bit platforms.
type SQLLEN       = SQLINTEGER
type SQLULEN      = SQLINTEGER

-- Calling convention used by the foreign imports below: stdcall when
-- mingw32_HOST_OS is defined, ccall otherwise.
#ifdef mingw32_HOST_OS
#let CALLCONV = "stdcall"
#else
#let CALLCONV = "ccall"
#endif

-- Raw bindings to the ODBC driver manager. Every function returns an
-- SQLRETURN status code which the callers feed to handleSqlResult.

-- Environment allocation. The matching release routine is imported as a
-- function pointer so that it can be attached as a ForeignPtr finalizer;
-- when mingw32_HOST_OS is defined a wrapper named my_sqlFreeEnv from
-- HsODBC.h is used instead of SQLFreeEnv itself.
foreign import #{CALLCONV} "HsODBC.h SQLAllocEnv" sqlAllocEnv :: Ptr HENV -> IO SQLRETURN
#ifdef mingw32_HOST_OS
foreign import ccall "HsODBC.h &my_sqlFreeEnv" sqlFreeEnv_p :: FunPtr (HENV -> IO ())
#else
foreign import ccall "HsODBC.h &SQLFreeEnv" sqlFreeEnv_p :: FunPtr (HENV -> IO ())
#endif

-- Connection management.
foreign import #{CALLCONV} "HsODBC.h SQLAllocConnect" sqlAllocConnect :: HENV -> Ptr HDBC -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLFreeConnect" sqlFreeConnect:: HDBC -> IO SQLRETURN
-- The three length arguments are SQLSMALLINT in the ODBC prototype, so
-- they are declared with that type here rather than as a machine-word Int.
foreign import #{CALLCONV} "HsODBC.h SQLConnect" sqlConnect :: HDBC -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLDriverConnect" sqlDriverConnect :: HDBC -> Ptr () -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT -> SQLUSMALLINT -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLDisconnect" sqlDisconnect :: HDBC -> IO SQLRETURN

-- Statement life cycle and result-set description.
foreign import #{CALLCONV} "HsODBC.h SQLAllocStmt" sqlAllocStmt :: HDBC -> Ptr HSTMT -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLFreeStmt" sqlFreeStmt :: HSTMT -> SQLUSMALLINT -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLNumResultCols" sqlNumResultCols :: HSTMT -> Ptr SQLUSMALLINT -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLDescribeCol" sqlDescribeCol :: HSTMT -> SQLUSMALLINT -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT -> Ptr SQLSMALLINT -> Ptr SQLULEN -> Ptr SQLSMALLINT -> Ptr SQLSMALLINT -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLBindCol" sqlBindCol :: HSTMT -> SQLUSMALLINT -> SQLSMALLINT -> Ptr a -> SQLLEN -> Ptr SQLINTEGER -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLFetch" sqlFetch :: HSTMT -> IO SQLRETURN

-- Diagnostics.
foreign import #{CALLCONV} "HsODBC.h SQLGetDiagRec" sqlGetDiagRec :: SQLSMALLINT -> SQLHANDLE -> SQLSMALLINT -> CString -> Ptr SQLINTEGER -> CString -> SQLSMALLINT -> Ptr SQLSMALLINT -> IO SQLRETURN

-- Execution, transactions and data retrieval.
-- NOTE(review): the text-length argument is SQLINTEGER in the ODBC
-- prototype but is declared as Int here; callers pass the Int obtained
-- from withCStringLen, so changing it needs a coordinated edit.
foreign import #{CALLCONV} "HsODBC.h SQLExecDirect" sqlExecDirect :: HSTMT -> CString -> Int -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLSetConnectOption" sqlSetConnectOption :: HDBC -> SQLUSMALLINT -> SQLULEN -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLTransact" sqlTransact :: HENV -> HDBC -> SQLUSMALLINT -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLGetData" sqlGetData :: HSTMT -> SQLUSMALLINT -> SQLSMALLINT -> Ptr () -> SQLINTEGER -> Ptr SQLINTEGER -> IO SQLRETURN

-- Catalog functions and multiple result sets.
foreign import #{CALLCONV} "HsODBC.h SQLTables" sqlTables :: HSTMT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLColumns" sqlColumns :: HSTMT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> CString -> SQLSMALLINT -> IO SQLRETURN
foreign import #{CALLCONV} "HsODBC.h SQLMoreResults" sqlMoreResults :: HSTMT -> IO SQLRETURN
#if defined(MSSQL_ODBC)
-- Only needed for the statement attributes set when MSSQL_ODBC is defined.
foreign import #{CALLCONV} "HsODBC.h SQLSetStmtAttr" sqlSetStmtAttr :: HSTMT -> SQLINTEGER -> SQLINTEGER -> SQLINTEGER -> IO SQLRETURN
#endif

-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------

-- | Turns an ODBC return code into either a normal return or a dynamic
-- exception.
--
-- * @handleType@ and @handle@ identify the handle whose diagnostic record
--   is consulted when the code is SQL_ERROR (and, in DEBUG builds, when it
--   is SQL_SUCCESS_WITH_INFO).
--
-- * SQL_SUCCESS, SQL_NO_DATA and SQL_SUCCESS_WITH_INFO return normally.
--
-- * SQL_INVALID_HANDLE, SQL_STILL_EXECUTING, SQL_NEED_DATA and SQL_ERROR
--   are thrown with 'throwDyn'; any other code is reported with 'error'.
handleSqlResult :: SQLSMALLINT -> SQLHANDLE -> SQLRETURN -> IO ()
handleSqlResult handleType handle res
        | res == (#const SQL_SUCCESS) || res == (#const SQL_NO_DATA) = return ()
        | res == (#const SQL_SUCCESS_WITH_INFO) = do
#ifdef DEBUG
            e <- getSqlError
            putTraceMsg (show e)
#else
            return ()
#endif
        | res == (#const SQL_INVALID_HANDLE) = throwDyn SqlInvalidHandle
        | res == (#const SQL_STILL_EXECUTING) = throwDyn SqlStillExecuting
        | res == (#const SQL_NEED_DATA) = throwDyn SqlNeedData
        | res == (#const SQL_ERROR) = do
            e <- getSqlError
            throwDyn e
        | otherwise = error (show res)
        where
          -- Fetches the first diagnostic record of the handle. The output
          -- buffers are only read when the driver manager reports that it
          -- actually filled them; otherwise they are uninitialised memory
          -- and peekCString could run past their end.
          getSqlError =
            allocaBytes 256 $ \pState ->
            alloca $ \pNative ->
            allocaBytes 256 $ \pMsg ->
            alloca $ \pTextLen -> do
                  diagRes <- sqlGetDiagRec handleType handle 1 pState pNative pMsg 256 pTextLen
                  if diagRes == (#const SQL_SUCCESS) || diagRes == (#const SQL_SUCCESS_WITH_INFO)
                    then do
                      state  <- peekCString pState
                      native <- peek pNative
                      msg    <- peekCString pMsg
                      return (SqlError {seState=state, seNativeError=fromIntegral native, seErrorMsg=msg})
                    else if diagRes == (#const SQL_INVALID_HANDLE)
                           -- e.g. the null handle passed while allocating the environment
                           then return SqlInvalidHandle
                           -- SQL_NO_DATA, or a failure of the diagnostic call itself
                           else return SqlNoData

-----------------------------------------------------------------------------------------
-- keeper of HENV
-----------------------------------------------------------------------------------------

{-# NOINLINE myEnvironment #-}
-- | The single ODBC environment handle used by every connection made
-- through this module. It is allocated on first use and released by the
-- finalizer attached to the foreign pointer.
myEnvironment :: HENVRef
myEnvironment = unsafePerformIO (alloca allocateEnvironment)
        where
                allocateEnvironment :: Ptr HENV -> IO HENVRef
                allocateEnvironment envSlot = do
                        status <- sqlAllocEnv envSlot
                        envHandle <- peek envSlot
                        handleSqlResult 0 nullPtr status
                        newForeignPtr sqlFreeEnv_p envHandle

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Opens a connection to a named ODBC data source using a user name and
-- password.
connect :: String               -- ^ Data source name
        -> String               -- ^ User identifier
        -> String               -- ^ Authentication string (password)
        -> IO Connection        -- ^ the returned value represents the new connection
connect server user authentication = connectHelper openDataSource
        where
                -- All three strings are passed NUL-terminated (SQL_NTS).
                openDataSource hDBC =
                        withCString server $ \cServer ->
                        withCString user $ \cUser ->
                        withCString authentication $ \cAuthentication ->
                        sqlConnect hDBC cServer (#const SQL_NTS) cUser (#const SQL_NTS) cAuthentication (#const SQL_NTS)

-- | 'driverConnect' is an alternative to 'connect' that takes a complete
-- ODBC connection string. Use it for data sources that need more
-- connection information than the three arguments of 'connect', or that
-- are not defined in the system information.
driverConnect :: String               -- ^ Connection string
              -> IO Connection        -- ^ the returned value represents the new connection
driverConnect connString = connectHelper openWithString
        where
                -- The completed connection string written back by the driver
                -- goes into a 1024-byte scratch buffer and is not used further;
                -- SQL_DRIVER_NOPROMPT is passed as the completion option.
                openWithString hDBC =
                        withCString connString $ \cConnString ->
                        allocaBytes 1024 $ \completedString ->
                        alloca $ \completedLen ->
                        sqlDriverConnect hDBC nullPtr cConnString (#const SQL_NTS) completedString 1024 completedLen (#const SQL_DRIVER_NOPROMPT)

-- Shared implementation of 'connect' and 'driverConnect': allocates a
-- connection handle in the module-wide environment, runs the supplied
-- connect action on it and wraps the handle in a 'Connection' record whose
-- operations are the local functions defined below.
connectHelper :: (HDBC -> IO SQLRETURN) -> IO Connection
connectHelper connectFunction = withForeignPtr myEnvironment $ \envHandle -> do
        dbc <- alloca $ \ (dbcSlot :: Ptr HDBC) -> do
            sqlAllocConnect envHandle dbcSlot >>= handleSqlResult (#const SQL_HANDLE_ENV) envHandle
            peek dbcSlot
        connectFunction dbc >>= handleSqlResult (#const SQL_HANDLE_DBC) dbc
        closedFlag <- newMVar False
        -- The record refers to itself so that statements can point back
        -- at the connection that created them.
        let conn = Connection
                { connDisconnect          = disconnect dbc
                , connExecute             = execute dbc
                , connQuery               = query conn dbc
                , connTables              = tables conn dbc
                , connDescribe            = describe conn dbc
                , connBeginTransaction    = beginTransaction myEnvironment dbc
                , connCommitTransaction   = commitTransaction myEnvironment dbc
                , connRollbackTransaction = rollbackTransaction myEnvironment dbc
                , connClosed              = closedFlag
                }
        return conn
        where
                -- Closes the session and then releases the connection handle.
                disconnect :: HDBC -> IO ()
                disconnect hDBC = do
                        closeStatus <- sqlDisconnect hDBC
                        checkDbc closeStatus
                        freeStatus <- sqlFreeConnect hDBC
                        checkDbc freeStatus
                        where checkDbc = handleSqlResult (#const SQL_HANDLE_DBC) hDBC

                -- Runs a statement that produces no result set: allocate a
                -- statement handle, execute the text directly, drop the handle.
                execute :: HDBC -> String -> IO ()
                execute hDBC query =
                        allocaBytes (#const sizeof(HSTMT)) $ \stmtSlot -> do
                                sqlAllocStmt hDBC stmtSlot >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC
                                stmt <- peek stmtSlot
                                let checkStmt = handleSqlResult (#const SQL_HANDLE_STMT) stmt
                                withCStringLen query $ \(sqlText, sqlLen) ->
                                        sqlExecDirect stmt sqlText sqlLen >>= checkStmt
                                sqlFreeStmt stmt (#const SQL_DROP) >>= checkStmt

                -- Size in bytes of the per-statement buffer that withStatement
                -- allocates and getColValue passes to SQLGetData.
                stmtBufferSize = 256

                -- Allocates a statement handle on hDBC, runs the action f on it
                -- (a query or a catalog call), positions on the first result
                -- set that has columns and returns a Statement record whose
                -- operations close over the handle and a freshly malloc'ed
                -- column buffer.
                --
                -- A temporary FIELD structure (declared in HsODBC.h) serves as
                -- scratch space for the output parameters of the ODBC calls.
                --
                -- NOTE(review): if f or any later step throws, neither the
                -- statement handle nor the buffer is released here.
                withStatement :: Connection -> HDBC -> (HSTMT -> IO SQLRETURN) -> IO Statement
                withStatement connection hDBC f = 
                        allocaBytes (#const sizeof(FIELD)) $ \pFIELD -> do
                        -- the new handle is written straight into the hSTMT slot of FIELD
                        res <- sqlAllocStmt hDBC ((#ptr FIELD, hSTMT) pFIELD)
                        handleSqlResult (#const SQL_HANDLE_DBC) hDBC res
                        hSTMT <- (#peek FIELD, hSTMT) pFIELD
                        let handleResult res = handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
#if defined(MSSQL_ODBC)
                        -- NOTE(review): the return codes of these two calls are discarded
                        sqlSetStmtAttr hSTMT (#const SQL_ATTR_ROW_ARRAY_SIZE) 2 (#const SQL_IS_INTEGER)
                        sqlSetStmtAttr hSTMT (#const SQL_ATTR_CURSOR_TYPE) (#const SQL_CURSOR_STATIC) (#const SQL_IS_INTEGER)
#endif
                        f hSTMT >>= handleResult
                        fields <- moveToFirstResult hSTMT pFIELD
                        -- heap buffer shared by getColValue and released by closeStatement
                        buffer <- mallocBytes (fromIntegral stmtBufferSize)
                        refFalse <- newMVar False
                        let statement = Statement
                                { stmtConn   = connection
                                , stmtClose  = closeStatement hSTMT buffer
                                , stmtFetch  = fetch hSTMT
                                , stmtGetCol = getColValue hSTMT buffer
                                , stmtFields = fields
                                , stmtClosed = refFalse
                                }
                        return statement
                        where
                                -- Skips over results that report zero columns until one
                                -- with columns is found or SQLMoreResults reports
                                -- SQL_NO_DATA; returns the field definitions of that
                                -- result, or [] when there is none.
                                moveToFirstResult :: HSTMT -> Ptr a -> IO [FieldDef]
                                moveToFirstResult hSTMT pFIELD = do
                                  res <- sqlNumResultCols hSTMT ((#ptr FIELD, fieldsCount) pFIELD)
                                  handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
                                  count <- (#peek FIELD, fieldsCount) pFIELD
                                  if count == 0
                                    then do
#if defined(MSSQL_ODBC)
                                      sqlSetStmtAttr hSTMT (#const SQL_ATTR_ROW_ARRAY_SIZE) 2 (#const SQL_IS_INTEGER)
                                      sqlSetStmtAttr hSTMT (#const SQL_ATTR_CURSOR_TYPE) (#const SQL_CURSOR_STATIC) (#const SQL_IS_INTEGER)
#endif
                                      res <- sqlMoreResults hSTMT
                                      handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
                                      if res == (#const SQL_NO_DATA)
                                        then return []
                                        else moveToFirstResult hSTMT pFIELD
                                    else
                                      getFieldDefs hSTMT pFIELD 1 count

                                -- Describes columns n..count (1-based) with SQLDescribeCol,
                                -- reusing the FIELD scratch structure for every column,
                                -- and converts each description into a FieldDef triple
                                -- of name, SQL type and nullability.
                                getFieldDefs :: HSTMT -> Ptr a -> SQLUSMALLINT -> SQLUSMALLINT -> IO [FieldDef]
                                getFieldDefs hSTMT pFIELD n count
                                        | n > count  = return []
                                        | otherwise = do
                                                res <- sqlDescribeCol hSTMT n ((#ptr FIELD, fieldName) pFIELD) (#const FIELD_NAME_LENGTH) ((#ptr FIELD, NameLength) pFIELD) ((#ptr FIELD, DataType) pFIELD) ((#ptr FIELD, ColumnSize) pFIELD) ((#ptr FIELD, DecimalDigits) pFIELD) ((#ptr FIELD, Nullable) pFIELD)
                                                handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
                                                -- everything must be read out before the recursive
                                                -- call overwrites the scratch structure
                                                name <- peekCString ((#ptr FIELD, fieldName) pFIELD)
                                                dataType <- (#peek FIELD, DataType) pFIELD
                                                columnSize <- (#peek FIELD, ColumnSize) pFIELD
                                                decimalDigits <- (#peek FIELD, DecimalDigits) pFIELD
                                                (nullable :: SQLSMALLINT) <- (#peek FIELD, Nullable) pFIELD
                                                let sqlType = mkSqlType dataType columnSize decimalDigits
                                                fields <- getFieldDefs hSTMT pFIELD (n+1) count
                                                return ((name,sqlType,toBool nullable):fields)

                -- Translates an ODBC type code, together with the reported
                -- column size and decimal digits, into the HSQL 'SqlType'.
                -- Codes that are not listed become 'SqlUnknown'.
                mkSqlType :: SQLSMALLINT -> SQLULEN -> SQLSMALLINT -> SqlType
                mkSqlType tp size prec = case tp of
                        (#const SQL_CHAR)          -> SqlChar (fromIntegral size)
                        (#const SQL_VARCHAR)       -> SqlVarChar (fromIntegral size)
                        (#const SQL_LONGVARCHAR)   -> SqlLongVarChar (fromIntegral size)
                        (#const SQL_DECIMAL)       -> SqlDecimal (fromIntegral size) (fromIntegral prec)
                        (#const SQL_NUMERIC)       -> SqlNumeric (fromIntegral size) (fromIntegral prec)
                        (#const SQL_SMALLINT)      -> SqlSmallInt
                        (#const SQL_INTEGER)       -> SqlInteger
                        (#const SQL_REAL)          -> SqlReal
                        -- From: http://msdn.microsoft.com/library/en-us/odbc/htm/odappdpr_2.asp
                        -- "Depending on the implementation, the precision of SQL_FLOAT can be either 24 or 53:
                        -- if it is 24, the SQL_FLOAT data type is the same as SQL_REAL;
                        -- if it is 53, the SQL_FLOAT data type is the same as SQL_DOUBLE."
                        (#const SQL_FLOAT)         -> SqlFloat
                        (#const SQL_DOUBLE)        -> SqlDouble
                        (#const SQL_BIT)           -> SqlBit
                        (#const SQL_TINYINT)       -> SqlTinyInt
                        (#const SQL_BIGINT)        -> SqlBigInt
                        (#const SQL_BINARY)        -> SqlBinary (fromIntegral size)
                        (#const SQL_VARBINARY)     -> SqlVarBinary (fromIntegral size)
                        (#const SQL_LONGVARBINARY) -> SqlLongVarBinary (fromIntegral size)
                        (#const SQL_DATE)          -> SqlDate
                        (#const SQL_TIME)          -> SqlTime
                        (#const SQL_TIMESTAMP)     -> SqlDateTime
                        (#const SQL_WCHAR)         -> SqlWChar (fromIntegral size)
                        (#const SQL_WVARCHAR)      -> SqlWVarChar (fromIntegral size)
                        (#const SQL_WLONGVARCHAR)  -> SqlWLongVarChar (fromIntegral size)
                        _                          -> SqlUnknown (fromIntegral tp)

                -- Executes a statement text directly and returns the open
                -- statement positioned on its first result set.
                query :: Connection -> HDBC -> String -> IO Statement
                query connection hDBC q = withStatement connection hDBC runDirect
                        where
                                runDirect hSTMT =
                                        withCStringLen q $ \(sqlText, sqlLen) ->
                                        sqlExecDirect hSTMT sqlText sqlLen

                -- Starts a transaction by switching the connection out of
                -- autocommit mode. The environment argument is not used here;
                -- it is accepted so that all three transaction operations take
                -- the same arguments. The ODBC status is checked, so a failure
                -- to change the mode is raised like any other ODBC error
                -- instead of being silently ignored.
                beginTransaction myEnvironment hDBC =
                        sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) (#const SQL_AUTOCOMMIT_OFF)
                                >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC

                -- Commits the current transaction and then switches the
                -- connection back to autocommit mode. Both ODBC status codes
                -- are checked: a failed commit raises an error (and leaves
                -- autocommit off) rather than being reported as success.
                commitTransaction myEnvironment hDBC = withForeignPtr myEnvironment $ \hEnv -> do
                        sqlTransact hEnv hDBC (#const SQL_COMMIT)
                                >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC
                        sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) (#const SQL_AUTOCOMMIT_ON)
                                >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC

                -- Rolls the current transaction back and then switches the
                -- connection back to autocommit mode. Both ODBC status codes
                -- are checked, so a failed rollback is raised as an error
                -- instead of being silently ignored.
                rollbackTransaction myEnvironment hDBC = withForeignPtr myEnvironment $ \hEnv -> do
                        sqlTransact hEnv hDBC (#const SQL_ROLLBACK)
                                >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC
                        sqlSetConnectOption hDBC (#const SQL_AUTOCOMMIT) (#const SQL_AUTOCOMMIT_ON)
                                >>= handleSqlResult (#const SQL_HANDLE_DBC) hDBC

                -- Lists the table names reported by SQLTables when no filter
                -- is given for any of its arguments.
                tables :: Connection -> HDBC -> IO [String]
                tables connection hDBC =
                        withStatement connection hDBC listAllTables >>= collectRows tableName
                        where
                                listAllTables hSTMT = sqlTables hSTMT nullPtr 0 nullPtr 0 nullPtr 0 nullPtr 0
                                -- SQLTables returns (column names may vary):
                                -- Column name     #   Type
                                -- TABLE_NAME      3   VARCHAR
                                tableName row = getFieldValue row "TABLE_NAME"

                -- Returns one field definition per column that SQLColumns
                -- reports for the given table name.
                describe :: Connection -> HDBC -> String -> IO [FieldDef]
                describe connection hDBC table =
                        withStatement connection hDBC listColumns >>= collectRows rowToFieldDef
                        where
                                -- Only the table name is constrained; the other
                                -- filters are passed as null.
                                listColumns hSTMT =
                                  withCStringLen table $ \(tableText, tableLen) ->
                                        sqlColumns hSTMT nullPtr 0 nullPtr 0 tableText (fromIntegral tableLen) nullPtr 0

                                -- SQLColumns returns (column names may vary):
                                -- Column name     #   Type
                                -- COLUMN_NAME     4   Varchar not NULL
                                -- DATA_TYPE       5   Smallint not NULL
                                -- COLUMN_SIZE     7   Integer
                                -- DECIMAL_DIGITS  9   Smallint
                                -- NULLABLE       11   Smallint not NULL
                                --
                                -- The fields are read in ascending column order.
                                rowToFieldDef row = do
                                  fieldName <- getFieldValue row "COLUMN_NAME"
                                  (typeCode :: Int) <- getFieldValue row "DATA_TYPE"
                                  (sizeValue :: Int) <- getFieldValue' row "COLUMN_SIZE" 0
                                  (digits :: Int) <- getFieldValue' row "DECIMAL_DIGITS" 0
                                  (nullFlag :: Int) <- getFieldValue row "NULLABLE"
                                  return ( fieldName
                                         , mkSqlType (fromIntegral typeCode) (fromIntegral sizeValue) (fromIntegral digits)
                                         , toBool nullFlag
                                         )

                -- Advances to the next row; the result is False once the
                -- driver reports SQL_NO_DATA.
                fetch :: HSTMT -> IO Bool
                fetch hSTMT = do
                        status <- sqlFetch hSTMT
                        handleSqlResult (#const SQL_HANDLE_STMT) hSTMT status
                        let exhausted = status == (#const SQL_NO_DATA)
                        return (not exhausted)

                -- Reads column colNumber (0-based) of the current row as
                -- character data and passes it to the conversion callback f
                -- together with the field definition.
                --
                --  * NULL values are passed as nullPtr with length 0.
                --  * Values that fit into the statement buffer are passed in place.
                --  * Values that were truncated are re-assembled in a temporary
                --    buffer that is only valid for the duration of f.
                --
                -- SQL_SUCCESS_WITH_INFO alone is not taken as proof of truncation:
                -- the reported length must also exceed the room that was offered
                -- (or be SQL_NO_TOTAL). Otherwise a warning on a short value would
                -- lead to copying the whole statement buffer into a smaller one.
                getColValue :: HSTMT -> CString -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
                getColValue hSTMT buffer colNumber fieldDef f = do
                        (res,len_or_ind) <- getData buffer (fromIntegral stmtBufferSize)
                        if len_or_ind == (#const SQL_NULL_DATA)
                          then f fieldDef nullPtr 0
                          else if isTruncated res len_or_ind stmtBufferSize
                                 then getLongData buffer firstChunk (remainingAfter len_or_ind firstChunk)
                                 else f fieldDef buffer (fromIntegral len_or_ind)
                        where
                                -- data bytes delivered by the first call; the last
                                -- byte of the buffer holds the terminating null
                                firstChunk :: Int
                                firstChunk = stmtBufferSize - 1

                                -- One SQLGetData call into the given buffer; returns the
                                -- status together with the length/indicator value.
                                getData :: CString -> SQLINTEGER -> IO (SQLRETURN, SQLINTEGER)
                                getData buffer size = alloca $ \lenP -> do
                                        res <- sqlGetData hSTMT (fromIntegral colNumber+1) (#const SQL_C_CHAR) (castPtr buffer) size lenP
                                        handleSqlResult (#const SQL_HANDLE_STMT) hSTMT res
                                        len_or_ind <- peek lenP
                                        return (res, len_or_ind)

                                -- A call was truncated when the driver says so and the
                                -- data available did not fit into room bytes (one of
                                -- which is taken by the terminating null), or when the
                                -- driver could not tell how much data there is.
                                isTruncated :: SQLRETURN -> SQLINTEGER -> Int -> Bool
                                isTruncated res len_or_ind room =
                                        res == (#const SQL_SUCCESS_WITH_INFO) &&
                                        (len_or_ind == (#const SQL_NO_TOTAL) || fromIntegral len_or_ind >= room)

                                -- Bytes still to be fetched after a truncated call that
                                -- delivered the given number of data bytes; an unknown
                                -- total stays unknown.
                                remainingAfter :: SQLINTEGER -> Int -> SQLINTEGER
                                remainingAfter len_or_ind delivered
                                        | len_or_ind == (#const SQL_NO_TOTAL) = len_or_ind
                                        | otherwise = len_or_ind - fromIntegral delivered

                                -- Gets called only when there is more data than would
                                -- fit in the buffer used so far. oldBuf holds have data
                                -- bytes. A new buffer big enough for the old and the
                                -- new data is created, the old data is copied into it
                                -- and the next SQLGetData call puts the rest after it.
                                -- The terminating null of the old data will always be
                                -- overwritten with the first byte of the new data.
                                -- When the total is unknown the buffer is roughly
                                -- doubled and the step repeats until the driver stops
                                -- reporting truncation.
                                getLongData oldBuf have remaining =
                                        allocaBytes (have + room) $ \newBuf -> do
                                                copyBytes newBuf oldBuf have
                                                (res,len_or_ind) <- getData (newBuf `plusPtr` have) (fromIntegral room)
                                                if isTruncated res len_or_ind room
                                                  then getLongData newBuf (have + room - 1) (remainingAfter len_or_ind (room - 1))
                                                  else f fieldDef newBuf (have + lastChunk len_or_ind)
                                        where
                                                totalUnknown :: Bool
                                                totalUnknown = remaining == (#const SQL_NO_TOTAL)

                                                -- space offered to the driver, including
                                                -- one byte for the terminating null
                                                room :: Int
                                                room | totalUnknown = have + 1
                                                     | otherwise    = fromIntegral remaining + 1

                                                -- With a known total the announced length is
                                                -- used; otherwise the length reported by the
                                                -- final call.
                                                lastChunk :: SQLINTEGER -> Int
                                                lastChunk reported
                                                        | totalUnknown = fromIntegral reported
                                                        | otherwise    = fromIntegral remaining

                -- Releases the column buffer and then drops the statement handle.
                closeStatement :: HSTMT -> CString -> IO ()
                closeStatement hSTMT buffer = do
                        free buffer
                        dropStatus <- sqlFreeStmt hSTMT (#const SQL_DROP)
                        handleSqlResult (#const SQL_HANDLE_STMT) hSTMT dropStatus

Theme by Vikram Singh | Powered by WebSVN v2.3.3