Subversion

2lt

?curdirlinks? - Rev 1

?prevdifflink? - Blame


-----------------------------------------------------------------------------------------
{-| Module      :  Database.HSQL.MySQL
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  ka2_mail@yahoo.com
    Stability   :  provisional
    Portability :  portable

    The module provides interface to MySQL database
-}
-----------------------------------------------------------------------------------------

module Database.HSQL.MySQL(connect, module Database.HSQL) where

import Database.HSQL
import Database.HSQL.Types
import Data.Dynamic
import Data.Bits
import Data.Char
import Foreign
import Foreign.C
import Control.Monad(when,unless)
import Control.Exception (throwDyn, finally)
import Control.Concurrent.MVar
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read

#include <HsMySQL.h>

type MYSQL = Ptr ()
type MYSQL_RES = Ptr ()
type MYSQL_FIELD = Ptr ()
type MYSQL_ROW = Ptr CString
type MYSQL_LENGTHS = Ptr CULong

#ifdef mingw32_HOST_OS
#let CALLCONV = "stdcall"
#else
#let CALLCONV = "ccall"
#endif

foreign import #{CALLCONV} "HsMySQL.h mysql_init" mysql_init :: MYSQL -> IO MYSQL
foreign import #{CALLCONV} "HsMySQL.h mysql_real_connect" mysql_real_connect :: MYSQL -> CString -> CString -> CString -> CString -> CInt -> CString -> CInt -> IO MYSQL
foreign import #{CALLCONV} "HsMySQL.h mysql_close" mysql_close :: MYSQL -> IO ()
foreign import #{CALLCONV} "HsMySQL.h mysql_errno" mysql_errno :: MYSQL -> IO CInt
foreign import #{CALLCONV} "HsMySQL.h mysql_error" mysql_error :: MYSQL -> IO CString
foreign import #{CALLCONV} "HsMySQL.h mysql_query" mysql_query :: MYSQL -> CString -> IO CInt
foreign import #{CALLCONV} "HsMySQL.h mysql_use_result" mysql_use_result :: MYSQL -> IO MYSQL_RES
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_field" mysql_fetch_field :: MYSQL_RES -> IO MYSQL_FIELD
foreign import #{CALLCONV} "HsMySQL.h mysql_free_result" mysql_free_result :: MYSQL_RES -> IO ()
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_row" mysql_fetch_row :: MYSQL_RES -> IO MYSQL_ROW
foreign import #{CALLCONV} "HsMySQL.h mysql_fetch_lengths" mysql_fetch_lengths :: MYSQL_RES -> IO MYSQL_LENGTHS
foreign import #{CALLCONV} "HsMySQL.h mysql_list_tables" mysql_list_tables :: MYSQL -> CString -> IO MYSQL_RES
foreign import #{CALLCONV} "HsMySQL.h mysql_list_fields" mysql_list_fields :: MYSQL -> CString -> CString -> IO MYSQL_RES
foreign import #{CALLCONV} "HsMySQL.h mysql_next_result" mysql_next_result :: MYSQL -> IO CInt



-----------------------------------------------------------------------------------------
-- routines for handling exceptions
-----------------------------------------------------------------------------------------

handleSqlError :: MYSQL -> IO a
handleSqlError pMYSQL = do
        errno <- mysql_errno pMYSQL
        errMsg <- mysql_error pMYSQL >>= peekCString
        throwDyn (SqlError "" (fromIntegral errno) errMsg)

-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------

-- | Makes a new connection to the database server.
connect :: String   -- ^ Server name
        -> String   -- ^ Database name
        -> String   -- ^ User identifier
        -> String   -- ^ Authentication string (password)
        -> IO Connection
connect server database user authentication = do
        pMYSQL <- mysql_init nullPtr
        pServer <- newCString server
        pDatabase <- newCString database
        pUser <- newCString user
        pAuthentication <- newCString authentication
        res <- mysql_real_connect pMYSQL pServer pUser pAuthentication pDatabase 0 nullPtr (#const MYSQL_DEFAULT_CONNECT_FLAGS)
        free pServer
        free pDatabase
        free pUser
        free pAuthentication
        when (res == nullPtr) (handleSqlError pMYSQL)
        refFalse <- newMVar False
        let connection = Connection
                { connDisconnect = mysql_close pMYSQL
                , connExecute = execute pMYSQL
                , connQuery = query connection pMYSQL
                , connTables = tables connection pMYSQL
                , connDescribe = describe connection pMYSQL
                , connBeginTransaction = execute pMYSQL "begin"
                , connCommitTransaction = execute pMYSQL "commit"
                , connRollbackTransaction = execute pMYSQL "rollback"
                , connClosed = refFalse
                }
        return connection
        where
                execute :: MYSQL -> String -> IO ()
                execute pMYSQL query = do
                        res <- withCString query (mysql_query pMYSQL)
                        when (res /= 0) (handleSqlError pMYSQL)

                withStatement :: Connection -> MYSQL -> MYSQL_RES -> IO Statement
                withStatement conn pMYSQL pRes = do
                        currRow  <- newMVar (nullPtr, nullPtr)
                        refFalse <- newMVar False
                        if (pRes == nullPtr)
                          then do
                                errno <- mysql_errno pMYSQL
                                when (errno /= 0) (handleSqlError pMYSQL)
                                return (Statement
                                      { stmtConn   = conn
                                      , stmtClose  = return ()
                                      , stmtFetch  = fetch pRes currRow
                                      , stmtGetCol = getColValue currRow
                                      , stmtFields = []
                                      , stmtClosed = refFalse
                                      })
                          else do
                                fieldDefs <- getFieldDefs pRes
                                return (Statement
                                      { stmtConn   = conn
                                      , stmtClose  = mysql_free_result pRes
                                      , stmtFetch  = fetch pRes currRow
                                      , stmtGetCol = getColValue currRow
                                      , stmtFields = fieldDefs
                                      , stmtClosed = refFalse
                                      })
                        where
                                getFieldDefs pRes = do
                                        pField <- mysql_fetch_field pRes
                                        if pField == nullPtr
                                          then return []
                                          else do
                                                name <- (#peek MYSQL_FIELD, name) pField >>= peekCString
                                                dataType <-  (#peek MYSQL_FIELD, type) pField
                                                columnSize <-  (#peek MYSQL_FIELD, length) pField
                                                flags <-  (#peek MYSQL_FIELD, flags) pField
                                                decimalDigits <-  (#peek MYSQL_FIELD, decimals) pField
                                                let sqlType = mkSqlType dataType columnSize decimalDigits
                                                defs <- getFieldDefs pRes
                                                return ((name,sqlType,((flags :: Int) .&. (#const NOT_NULL_FLAG)) == 0):defs)

                                mkSqlType :: Int -> Int -> Int -> SqlType
                                mkSqlType (#const FIELD_TYPE_STRING)     size _    = SqlChar size
                                mkSqlType (#const FIELD_TYPE_VAR_STRING) size _    = SqlVarChar size
                                mkSqlType (#const FIELD_TYPE_DECIMAL)    size prec = SqlNumeric size prec
                                mkSqlType (#const FIELD_TYPE_SHORT)      _    _    = SqlSmallInt
                                mkSqlType (#const FIELD_TYPE_INT24)      _    _    = SqlMedInt
                                mkSqlType (#const FIELD_TYPE_LONG)       _    _    = SqlInteger
                                mkSqlType (#const FIELD_TYPE_FLOAT)      _    _    = SqlReal
                                mkSqlType (#const FIELD_TYPE_DOUBLE)     _    _    = SqlDouble
                                mkSqlType (#const FIELD_TYPE_TINY)       _    _    = SqlTinyInt
                                mkSqlType (#const FIELD_TYPE_LONGLONG)   _    _    = SqlBigInt
                                mkSqlType (#const FIELD_TYPE_DATE)       _    _    = SqlDate
                                mkSqlType (#const FIELD_TYPE_TIME)       _    _    = SqlTime
                                mkSqlType (#const FIELD_TYPE_TIMESTAMP)  _    _    = SqlTimeStamp
                                mkSqlType (#const FIELD_TYPE_DATETIME)   _    _    = SqlDateTime
                                mkSqlType (#const FIELD_TYPE_YEAR)       _    _    = SqlYear
                                mkSqlType (#const FIELD_TYPE_BLOB)       _    _    = SqlBLOB
                                mkSqlType (#const FIELD_TYPE_SET)        _    _    = SqlSET
                                mkSqlType (#const FIELD_TYPE_ENUM)       _    _    = SqlENUM
                                mkSqlType tp                             _    _    = SqlUnknown tp

                query :: Connection -> MYSQL -> String -> IO Statement
                query conn pMYSQL query = do
                        res <- withCString query (mysql_query pMYSQL)
                        when (res /= 0) (handleSqlError pMYSQL)
                        pRes <- getFirstResult pMYSQL
                        withStatement conn pMYSQL pRes
                        where
                          getFirstResult :: MYSQL -> IO MYSQL_RES
                          getFirstResult pMYSQL = do
                            pRes <- mysql_use_result pMYSQL
                            if pRes == nullPtr
                              then do
                                res <- mysql_next_result pMYSQL
                                if res == 0
                                  then getFirstResult pMYSQL
                                  else return nullPtr
                              else return pRes

                fetch :: MYSQL_RES -> MVar (MYSQL_ROW, MYSQL_LENGTHS) -> IO Bool
                fetch pRes currRow
                        | pRes == nullPtr = return False
                        | otherwise = modifyMVar currRow $ \(pRow, pLengths) -> do
                                pRow <- mysql_fetch_row pRes
                                pLengths <- mysql_fetch_lengths pRes
                                return ((pRow, pLengths), pRow /= nullPtr)

                getColValue :: MVar (MYSQL_ROW, MYSQL_LENGTHS) -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
                getColValue currRow colNumber fieldDef f = do
                        (row, lengths) <- readMVar currRow
                        pValue <- peekElemOff row colNumber
                        len <- fmap fromIntegral (peekElemOff lengths colNumber)
                        f fieldDef pValue len

                tables :: Connection -> MYSQL -> IO [String]
                tables conn pMYSQL = do
                        pRes <- mysql_list_tables pMYSQL nullPtr
                        stmt <- withStatement conn pMYSQL pRes
                        -- SQLTables returns:
                        -- Column name     #   Type
                        -- Tables_in_xx      0   VARCHAR
                        collectRows (\stmt -> do
                                mb_v <- stmtGetCol stmt 0 ("Tables", SqlVarChar 0, False) fromSqlCStringLen
                                return (case mb_v of { Nothing -> ""; Just a -> a })) stmt

                describe :: Connection -> MYSQL -> String -> IO [FieldDef]
                describe conn pMYSQL table = do
                        pRes <- withCString table (\table -> mysql_list_fields pMYSQL table nullPtr)
                        stmt <- withStatement conn pMYSQL pRes
                        return (getFieldsTypes stmt)

Theme by Vikram Singh | Powered by WebSVN v2.3.3