Subversion

2lt

?curdirlinks? - Rev 1

?prevdifflink? - Blame


-----------------------------------------------------------------------------------------
{-| Module      :  Database.HSQL.Oracle
    Copyright   :  (c) Krasimir Angelov 2003
    License     :  BSD-style

    Maintainer  :  kr.angelov@gmail.com
    Stability   :  provisional
    Portability :  portable

    This module provides an interface to Oracle databases through the
    Oracle Call Interface (OCI).
-}
-----------------------------------------------------------------------------------------

module Database.HSQL.Oracle(connect, module Database.HSQL) where

import Control.Concurrent.MVar
import Control.Exception(finally, throwDyn)
import Data.Word
import Foreign
import Foreign.C
import Foreign.Concurrent as FC

import Database.HSQL
import Database.HSQL.Types

#include <HsOCI.h>

-- OCI handles are passed around as untyped pointers; the aliases below only
-- document which kind of handle a function expects, they are not distinct
-- types to the compiler.
type OCIHandle = Ptr ()
type OCIEnv    = OCIHandle
type OCIError  = OCIHandle
type OCISvcCtx = OCIHandle
type OCIStmt   = OCIHandle
type OCIParam  = OCIHandle
type OCIDefine = OCIHandle
type OCIDescribe=OCIHandle

-- Owning reference to the process-wide OCI environment (see myEnvironment).
type OCIEnvRef = ForeignPtr ()

-- Raw OCI entry points.  Argument types follow the OCI C prototypes:
-- ub4/sb4 arguments are passed as CInt, ub1 as Word8, size_t as CSize,
-- uword as CUInt, and every OCI handle as an untyped pointer.

-- Environment and handle management.  The xtramem_sz argument of
-- OCIEnvCreate and OCIHandleAlloc is a size_t in C, hence CSize.
foreign import ccall "OCIEnvCreate" ociEnvCreate :: Ptr OCIEnv -> CInt -> Ptr a -> FunPtr a -> FunPtr a -> FunPtr a -> CSize -> Ptr (Ptr a) -> IO CInt
foreign import ccall "OCITerminate" ociTerminate :: CInt -> IO CInt
foreign import ccall "OCIHandleAlloc" ociHandleAlloc :: OCIHandle -> Ptr OCIHandle -> CInt -> CSize -> Ptr a -> IO CInt
foreign import ccall "OCIHandleFree" ociHandleFree :: OCIHandle -> CInt -> IO CInt
foreign import ccall "OCIErrorGet" ociErrorGet :: OCIHandle -> CInt -> CString -> Ptr CInt -> CString -> CInt -> CInt -> IO CInt

-- Sessions.
foreign import ccall "OCILogon" ociLogon :: OCIEnv -> OCIError -> Ptr OCISvcCtx -> CString -> CInt -> CString -> CInt -> CString -> CInt -> IO CInt
foreign import ccall "OCILogoff" ociLogoff :: OCISvcCtx -> OCIError -> IO CInt

-- Statements.
foreign import ccall "OCIStmtPrepare" ociStmtPrepare :: OCIStmt -> OCIError -> CString -> CInt -> CInt -> CInt -> IO CInt
foreign import ccall "OCIStmtExecute" ociStmtExecute :: OCISvcCtx -> OCIStmt -> OCIError -> CInt -> CInt -> OCIHandle -> OCIHandle -> CInt -> IO CInt
foreign import ccall "OCIStmtFetch2" ociStmtFetch2 :: OCIStmt -> OCIError -> CInt -> CInt -> CInt -> CInt -> IO CInt
foreign import ccall "OCIDefineByPos" ociDefineByPos :: OCIStmt -> Ptr OCIDefine -> OCIError -> CInt -> Ptr a -> CInt -> CShort -> Ptr CShort -> Ptr CShort -> Ptr CShort -> CInt -> IO CInt

-- Describing result sets and schema objects.
foreign import ccall "OCIParamGet" ociParamGet :: OCIHandle -> CInt -> OCIError -> Ptr OCIParam -> CInt -> IO CInt
foreign import ccall "OCIAttrGet" ociAttrGet :: OCIParam -> CInt -> Ptr a -> Ptr CInt -> CInt -> OCIError -> IO CInt
foreign import ccall "OCIDescribeAny" ociDescribeAny :: OCISvcCtx -> OCIError -> Ptr () -> CInt -> Word8 -> Word8 -> Word8 -> OCIDescribe -> IO CInt

foreign import ccall "OCIDescriptorFree" ociDescriptorFree :: OCIHandle -> CInt -> IO CInt

-- Transactions.  The timeout argument of OCITransStart is a C uword
-- (unsigned int), hence CUInt.
foreign import ccall "OCITransStart" ociTransStart :: OCISvcCtx -> OCIError -> CUInt -> CInt -> IO CInt
foreign import ccall "OCITransCommit" ociTransCommit :: OCISvcCtx -> OCIError -> CInt -> IO CInt
foreign import ccall "OCITransRollback" ociTransRollback :: OCISvcCtx -> OCIError -> CInt -> IO CInt

-- C library strlen; returns size_t.
foreign import ccall "strlen" strlen :: CString -> IO CSize

-----------------------------------------------------------------------------------------
-- keeper of OCIEnv
-----------------------------------------------------------------------------------------

{-# NOINLINE myEnvironment #-}
-- | The process-wide OCI environment shared by every connection.
--
-- It is created on first use with OCIEnvCreate in OCI_DEFAULT mode.  The
-- 'unsafePerformIO' / NOINLINE combination is the usual idiom for a global
-- that must be initialised only once.
--
-- When the 'ForeignPtr' is finalized, the environment handle returned by
-- OCIEnvCreate is released with OCIHandleFree, and OCITerminate is called
-- afterwards.
myEnvironment :: OCIEnvRef
myEnvironment = unsafePerformIO $ alloca $ \ pOCIEnv -> do
  ociEnvCreate pOCIEnv (#const OCI_DEFAULT) nullPtr nullFunPtr nullFunPtr nullFunPtr 0 nullPtr >>= handleSqlResult nullPtr
  env <- peek pOCIEnv
  FC.newForeignPtr env (release env)
  where
    -- free the environment handle itself, then detach from OCI
    release env = do
      ociHandleFree env (#const OCI_HTYPE_ENV) >>= handleSqlResult nullPtr
      ociTerminate (#const OCI_DEFAULT) >>= handleSqlResult nullPtr

-----------------------------------------------------------------------------------------
-- error handling
-----------------------------------------------------------------------------------------

-- | Check the status code returned by an OCI call.
--
-- * @OCI_SUCCESS@ and @OCI_NO_DATA@ are accepted.
--
-- * @OCI_SUCCESS_WITH_INFO@ is accepted as well.  In @DEBUG@ builds its
--   diagnostic is traced with putTraceMsg, which needs "Debug.Trace" in
--   scope for such builds.
--
-- * @OCI_INVALID_HANDLE@, @OCI_STILL_EXECUTING@ and @OCI_NEED_DATA@ throw
--   'SqlInvalidHandle', 'SqlStillExecuting' and 'SqlNeedData' respectively.
--
-- * @OCI_ERROR@ throws the 'SqlError' built from the first diagnostic record
--   of the error handle @err@.  If no record can be retrieved, 'SqlNoData' is
--   thrown instead; presumably this is what happens when @err@ is 'nullPtr'.
--
-- * Any other code is reported with 'error'.
handleSqlResult :: OCIError -> CInt -> IO ()
handleSqlResult err res
  | res == (#const OCI_SUCCESS) || res == (#const OCI_NO_DATA) = return ()
  | res == (#const OCI_SUCCESS_WITH_INFO) = do
#ifdef DEBUG
      e <- getSqlError
      putTraceMsg (show e)
#else
      return ()
#endif
  | res == (#const OCI_INVALID_HANDLE)  = throwDyn SqlInvalidHandle
  | res == (#const OCI_STILL_EXECUTING) = throwDyn SqlStillExecuting
  | res == (#const OCI_NEED_DATA)       = throwDyn SqlNeedData
  | res == (#const OCI_ERROR)           = getSqlError >>= throwDyn
  | otherwise = error (show res)
  where
    -- capacity, in bytes, of the message buffer handed to OCIErrorGet
    stringBufferLen = 1024

    -- Fetch diagnostic record 1 from the error handle.
    getSqlError =
      alloca $ \pErrCode ->
      allocaBytes stringBufferLen $ \pErrMsg -> do
        rc <- ociErrorGet err 1 nullPtr pErrCode pErrMsg (fromIntegral stringBufferLen) (#const OCI_HTYPE_ERROR)
        -- A negative code means the handle could not be queried, and
        -- OCI_NO_DATA means there is no record.  In both cases OCI has not
        -- filled pErrCode/pErrMsg, so they must not be read.
        if rc < 0 || rc == (#const OCI_NO_DATA)
          then return SqlNoData
          else do
            msg <- peekCString pErrMsg
            errCode <- peek pErrCode
            return (SqlError {seState="", seNativeError=fromIntegral errCode, seErrorMsg=msg})

-- | Makes a new connection to the Oracle service.
--
-- A dedicated OCI error handle is allocated for the connection, and a
-- session is opened with OCILogon.  If the logon fails, that error handle is
-- released before the error is rethrown.  The returned 'Connection' closes
-- over the session's service context and error handle; its 'connDisconnect'
-- logs off and frees the error handle.
connect :: String               -- ^ Service name
        -> String               -- ^ User identifier
        -> String               -- ^ Password
        -> IO Connection        -- ^ the returned value represents the new connection
connect service user pwd =
  withForeignPtr myEnvironment $ \env ->
  withCStringLen user    $ \(cUser, cUserLen) ->
  withCStringLen pwd     $ \(cPwd, cPwdLen) ->
  withCStringLen service $ \(cService, cServiceLen) ->
  alloca $ \pError ->
  alloca $ \pSvcCtx -> do
    ociHandleAlloc env pError (#const OCI_HTYPE_ERROR) 0 nullPtr >>= handleSqlResult nullPtr
    err <- peek pError
    res <- ociLogon env err pSvcCtx cUser (fromIntegral cUserLen) cPwd (fromIntegral cPwdLen) cService (fromIntegral cServiceLen)
    -- handleSqlResult returns normally only for these three codes; for any
    -- other it throws.  The error text has to be read from 'err' first, so
    -- the handle is released afterwards via 'finally'.
    if res == (#const OCI_SUCCESS) || res == (#const OCI_SUCCESS_WITH_INFO) || res == (#const OCI_NO_DATA)
      then handleSqlResult err res
      else handleSqlResult err res `finally` ociHandleFree err (#const OCI_HTYPE_ERROR)
    svcCtx <- peek pSvcCtx
    refFalse <- newMVar False
    -- NOTE(review): tables/describe receive the raw env pointer and use it
    -- after withForeignPtr has returned.  They presumably rely on
    -- myEnvironment being kept alive by the other fields, which reference it;
    -- confirm.
    let connection = Connection
          { connDisconnect          = disconnect svcCtx err
          , connExecute             = execute myEnvironment svcCtx err
          , connQuery               = query connection myEnvironment svcCtx err
          , connTables              = tables env svcCtx err
          , connDescribe            = describe env svcCtx err
          , connBeginTransaction    = beginTransaction myEnvironment svcCtx err
          , connCommitTransaction   = commitTransaction myEnvironment svcCtx err
          , connRollbackTransaction = rollbackTransaction myEnvironment svcCtx err
          , connClosed              = refFalse
          }
    return connection
  where
    -- End the session with OCILogoff, then free the connection's error
    -- handle.  The handle is freed even when the logoff reports an error;
    -- that error is read from 'err' before the free runs.
    disconnect svcCtx err =
      (ociLogoff svcCtx err >>= handleSqlResult err)
        `finally` (ociHandleFree err (#const OCI_HTYPE_ERROR) >>= handleSqlResult err)

    -- Prepare and execute one SQL statement (iters = 1), discarding any
    -- result.  A statement handle is allocated for the call and is always
    -- released afterwards, including when prepare or execute fails.
    execute envRef svcCtx err sqlText =
      withForeignPtr envRef $ \env ->
      withCStringLen sqlText $ \(cSql, cSqlLen) ->
      alloca $ \pStmt -> do
        ociHandleAlloc env pStmt (#const OCI_HTYPE_STMT) 0 nullPtr >>= handleSqlResult err
        stmt <- peek pStmt
        let run = do
              ociStmtPrepare stmt err cSql (fromIntegral cSqlLen) (#const OCI_NTV_SYNTAX) (#const OCI_DEFAULT) >>= handleSqlResult err
              ociStmtExecute svcCtx stmt err 1 0 nullPtr nullPtr (#const OCI_DEFAULT) >>= handleSqlResult err
        run `finally` (ociHandleFree stmt (#const OCI_HTYPE_STMT) >>= handleSqlResult err)

    -- Prepare and execute a query, describe its result columns, and bind an
    -- output slot for each column.  Returns a 'Statement' over the result.
    --
    -- All column values of the current row live in one mallocBytes block,
    -- which closeStatement releases.  The block starts with a header of one
    -- CInt per column, giving the byte offset of that column's slot, and the
    -- slots follow.  Each column is fetched as a NUL-terminated string
    -- (SQLT_STR) into a slot of sqlType2Size bytes, and getColValue reads it
    -- back from there.
    --
    -- If anything fails before the 'Statement' is returned, the statement
    -- handle (and the buffer, once allocated) is released before the error
    -- propagates.
    query connection envRef svcCtx err sqlText =
      withForeignPtr envRef $ \env ->
      withCStringLen sqlText $ \(cSql, cSqlLen) ->
      alloca $ \pStmt -> do
        ociHandleAlloc env pStmt (#const OCI_HTYPE_STMT) 0 nullPtr >>= handleSqlResult err
        stmt <- peek pStmt
        -- the result of the free is ignored so that the original error wins
        openStatement stmt cSql cSqlLen
          `onFailure` (ociHandleFree stmt (#const OCI_HTYPE_STMT) >> return ())
      where
        openStatement stmt cSql cSqlLen = do
          ociStmtPrepare stmt err cSql (fromIntegral cSqlLen) (#const OCI_NTV_SYNTAX) (#const OCI_DEFAULT) >>= handleSqlResult err
          -- iters = 0: the statement is executed; rows are fetched later by 'fetch'
          ociStmtExecute svcCtx stmt err 0 0 nullPtr nullPtr (#const OCI_DEFAULT) >>= handleSqlResult err
          fields <- allocaBytes (#const (sizeof(FIELD_DEF))) (getFieldDefs stmt 1)
          let headerSize = fromIntegral (length fields * sizeOf (0 :: CInt)) :: CInt
          buffer <- mallocBytes (fromIntegral (foldr ((+) . sqlType2Size) headerSize fields))
          definePositions stmt buffer 0 headerSize fields `onFailure` free buffer
          refFalse <- newMVar False
          let statement = Statement
                { stmtConn   = connection
                , stmtClose  = closeStatement stmt buffer err
                , stmtFetch  = fetch stmt err
                , stmtGetCol = getColValue buffer
                , stmtFields = fields
                , stmtClosed = refFalse
                }
          return statement

        -- Run 'action'; if it throws, run 'cleanup' and let the exception propagate.
        onFailure action cleanup = do
          succeeded <- newMVar False
          let markDone r = swapMVar succeeded True >> return r
              cleanupUnlessDone = readMVar succeeded >>= \ok -> if ok then return () else cleanup
          (action >>= markDone) `finally` cleanupUnlessDone

        -- Describe result columns 1, 2, ... through the FIELD_DEF scratch
        -- buffer.  Stops at the first position for which OCIParamGet does not
        -- return OCI_SUCCESS (i.e. past the last column).
        getFieldDefs stmt counter buffer = do
          res <- ociParamGet stmt (#const OCI_HTYPE_STMT) err ((#ptr FIELD_DEF, par) buffer) counter
          if res == (#const OCI_SUCCESS)
            then do field  <- getFieldDef err buffer
                    fields <- getFieldDefs stmt (counter+1) buffer
                    return (field:fields)
            else return []

        -- Size in bytes of the slot for a column fetched as SQLT_STR,
        -- including the terminating NUL.
        sqlType2Size :: FieldDef -> CInt
        sqlType2Size (_,tp,_) =
          case tp of
            SqlChar n      -> fromIntegral n+1 -- CHAR(n) columns (see mkSqlType)
            SqlVarChar n   -> fromIntegral n+1
            SqlNumeric p s -> fromIntegral (p+s+3) -- The value precision plus optional positions for '.', '-' and
                                                   -- one position for the '\0' character at end of the string.
                                                   -- NOTE(review): the scale arrives as an unsigned byte (see
                                                   -- mkSqlType), so negative OCI scales become large values here.
            SqlInteger     -> 16 -- 12 digits are enough (maxBound :: Int) has 10 digits.
                                 -- in addition we may need one position for '-' and one
                                 -- for the '\0' character at end of the string.
            SqlFloat       -> 100
            SqlDate        -> 100
            SqlTime        -> 100
            SqlTimeTZ      -> 100
            SqlTimeStamp   -> 100
            SqlText        -> 100
            SqlUnknown _   -> 0

        -- Bind result column pos+1 (OCI positions are 1-based) to its slot,
        -- after recording the slot's offset in header entry pos.
        -- NOTE(review): no indicator variables are bound (the nullPtr
        -- arguments), so fetching a SQL NULL presumably fails inside OCI;
        -- confirm.
        definePositions _ _ _ _ [] = return ()
        definePositions stmt buffer pos offset (field:rest) =
          alloca $ \pDef -> do
            let size = sqlType2Size field
            poke (castPtr buffer `advancePtr` fromIntegral pos) offset
            ociDefineByPos stmt pDef err (pos+1) (buffer `plusPtr` fromIntegral offset) size (#const SQLT_STR) nullPtr nullPtr nullPtr (#const OCI_DEFAULT)
              >>= handleSqlResult err
            definePositions stmt buffer (pos+1) (offset+size) rest

    -- Translate an OCI data type code, plus the size, precision and scale
    -- attributes of a column, into an HSQL 'SqlType'.
    --   * Size is only used for CHAR/VARCHAR codes.
    --   * Precision and scale are only used for SQLT_NUM.
    --   * Any unlisted code becomes 'SqlUnknown' carrying the raw code.
    -- NOTE(review): Oracle DATE columns are presumably described as SQLT_DAT,
    -- which is not listed and so maps to SqlUnknown; confirm.
    mkSqlType :: (#type OCITypeCode) -> (#type ub2) -> (#type ub1) -> (#type ub1) -> SqlType
    mkSqlType code size prec scale =
      case code of
        (#const SQLT_CHR)       -> SqlVarChar (fromIntegral size)
        (#const SQLT_AFC)       -> SqlChar    (fromIntegral size)
        (#const SQLT_NUM)       -> SqlNumeric (fromIntegral prec) (fromIntegral scale)
        (#const SQLT_INT)       -> SqlInteger
        (#const SQLT_FLT)       -> SqlFloat
        (#const SQLT_DATE)      -> SqlDate
        (#const SQLT_TIME)      -> SqlTime
        (#const SQLT_TIME_TZ)   -> SqlTimeTZ
        (#const SQLT_TIMESTAMP) -> SqlTimeStamp
        (#const SQLT_LNG)       -> SqlText
        _                       -> SqlUnknown (fromIntegral code)

    -- Return the names of the tables in the schema "COREDB_SYSTEM".
    --
    -- The steps are:
    --   1. Describe the schema with OCIDescribeAny (OCI_PTYPE_SCHEMA).
    --   2. Take its object list (OCI_ATTR_LIST_OBJECTS).
    --   3. Walk the list with OCIParamGet.
    --   4. Keep the name of every object whose OCI_ATTR_PTYPE is
    --      OCI_PTYPE_TABLE.
    -- Descriptors and the describe handle are released before returning.
    --
    -- NOTE(review): the schema name is hard-coded instead of being derived
    -- from the connection (e.g. the logged-on user), so this presumably
    -- lists the tables of one specific schema only.  Confirm, and consider
    -- passing the schema in from connect.
    tables env svcCtx err =
      withCStringLen "COREDB_SYSTEM" $ \(cstr,clen) ->
      alloca $ \pDescr ->
      alloca $ \pParam ->
      alloca $ \pColl -> do
        ociHandleAlloc env pDescr (#const OCI_HTYPE_DESCRIBE) 0 nullPtr >>= handleSqlResult err
        descr <- peek pDescr
        ociDescribeAny svcCtx err (castPtr cstr) (fromIntegral clen) (#const OCI_OTYPE_NAME) (#const OCI_DEFAULT) (#const OCI_PTYPE_SCHEMA) descr >>= handleSqlResult err
        -- top-level parameter descriptor of the described schema
        ociAttrGet descr (#const OCI_HTYPE_DESCRIBE) pParam nullPtr (#const OCI_ATTR_PARAM) err >>= handleSqlResult err
        param <- peek pParam
        -- list of the schema's objects; elements are fetched by getTableNames
        ociAttrGet param (#const OCI_DTYPE_PARAM) pColl nullPtr (#const OCI_ATTR_LIST_OBJECTS) err >>= handleSqlResult err
        coll <- peek pColl
        -- a FIELD_DEF-sized scratch buffer receives each OCIParamGet/OCIAttrGet result
        names <- allocaBytes (#const (sizeof(FIELD_DEF))) (getTableNames coll 1)
        -- the OCIDescriptorFree results are not checked
        ociDescriptorFree coll  (#const OCI_DTYPE_PARAM)
        ociDescriptorFree param (#const OCI_DTYPE_PARAM)
        ociHandleFree descr (#const OCI_HTYPE_DESCRIBE) >>= handleSqlResult err
        return names
      where
        -- Walk the object list from position 'index' onwards until
        -- OCIParamGet stops returning OCI_SUCCESS.  Each object's name is
        -- read, and it is kept only if its type is OCI_PTYPE_TABLE.
        getTableNames coll index buffer = do
          res <- ociParamGet coll (#const OCI_DTYPE_PARAM) err ((#ptr FIELD_DEF, par) buffer) index
          -- read before 'res' is checked; the value is only used on success
          par <- (#peek FIELD_DEF, par) buffer
          if res == (#const OCI_SUCCESS)
            then do
              -- the object name comes back as a (pointer, ub4 length) pair
              ociAttrGet par (#const OCI_DTYPE_PARAM) ((#ptr FIELD_DEF, colName) buffer) ((#ptr FIELD_DEF, colNameLen) buffer) (#const OCI_ATTR_OBJ_NAME) err >>= handleSqlResult err
              pName <- (#peek FIELD_DEF, colName) buffer
              nameLen <- (#peek FIELD_DEF, colNameLen) buffer
              name <- peekCStringLen (pName, fromIntegral (nameLen :: (#type ub4)))
              -- the dtype slot of FIELD_DEF is reused to receive the object's type (read as ub1)
              ociAttrGet par (#const OCI_DTYPE_PARAM) ((#ptr FIELD_DEF, dtype) buffer) nullPtr (#const OCI_ATTR_PTYPE) err >>= handleSqlResult err
              ptype <- (#peek FIELD_DEF, dtype) buffer
              ociDescriptorFree par (#const OCI_DTYPE_PARAM)
              names <- getTableNames coll (index+1) buffer
              return $! (if ptype == ((#const OCI_PTYPE_TABLE) :: (#type ub1))
                           then name:names
                           else      names)
            else return []

    -- Describe the columns of table @tblName@ with OCIDescribeAny
    -- (OCI_PTYPE_TABLE).  Returns one FieldDef per column, in column order;
    -- the number of columns comes from OCI_ATTR_NUM_COLS.  Descriptors and
    -- the describe handle are released before returning.
    describe env svcCtx err tblName =
      withCStringLen tblName $ \(nameBuf, nameLen) ->
      alloca $ \pDescr ->
      alloca $ \pParam ->
      alloca $ \pColl ->
      alloca $ \pNumcols -> do
        ociHandleAlloc env pDescr (#const OCI_HTYPE_DESCRIBE) 0 nullPtr >>= handleSqlResult err
        descr <- peek pDescr
        ociDescribeAny svcCtx err (castPtr nameBuf) (fromIntegral nameLen) (#const OCI_OTYPE_NAME) (#const OCI_DEFAULT) (#const OCI_PTYPE_TABLE) descr >>= handleSqlResult err
        -- top-level parameter descriptor of the described table
        ociAttrGet descr (#const OCI_HTYPE_DESCRIBE) pParam nullPtr (#const OCI_ATTR_PARAM) err >>= handleSqlResult err
        param <- peek pParam
        ociAttrGet param (#const OCI_DTYPE_PARAM) pNumcols nullPtr (#const OCI_ATTR_NUM_COLS) err >>= handleSqlResult err
        columnCount <- peek (pNumcols :: Ptr (#type ub2))
        -- column list; its elements are fetched one by one in readColumns
        ociAttrGet param (#const OCI_DTYPE_PARAM) pColl nullPtr (#const OCI_ATTR_LIST_COLUMNS) err >>= handleSqlResult err
        coll <- peek pColl
        columns <- allocaBytes (#const (sizeof(FIELD_DEF))) (readColumns coll (fromIntegral columnCount))
        ociDescriptorFree coll  (#const OCI_DTYPE_PARAM)
        ociDescriptorFree param (#const OCI_DTYPE_PARAM)
        ociHandleFree descr (#const OCI_HTYPE_DESCRIBE) >>= handleSqlResult err
        return columns
      where
        -- Read columns 1..total of list @coll@ through the FIELD_DEF scratch buffer.
        readColumns coll total buffer = go 1
          where
            go index
              | index > total = return []
              | otherwise = do
                  ociParamGet coll (#const OCI_DTYPE_PARAM) err ((#ptr FIELD_DEF, par) buffer) index >>= handleSqlResult err
                  column <- getFieldDef err buffer
                  (column :) `fmap` go (index + 1)

    -- Read the description of one column and return it as a
    -- (name, type, nullable) triple.
    --
    -- The column's parameter descriptor is taken from the @par@ field of the
    -- FIELD_DEF scratch @buffer@, where a preceding OCIParamGet stored it.
    -- Each attribute is fetched with OCIAttrGet into its own FIELD_DEF slot
    -- and read back from there.  The descriptor is then freed; that result is
    -- not checked.
    --
    -- NOTE(review): OCI documents OCI_ATTR_SCALE as signed (sb1), and
    -- OCI_ATTR_PRECISION as sb2 for implicit describes.  Both are read here
    -- with the unsigned ub1 types mkSqlType expects; confirm against the
    -- FIELD_DEF declaration in HsOCI.h.
    getFieldDef err buffer = do
      param <- (#peek FIELD_DEF, par) buffer
      let fetchAttr dest attrId =
            ociAttrGet param (#const OCI_DTYPE_PARAM) dest nullPtr attrId err >>= handleSqlResult err
      fetchAttr ((#ptr FIELD_DEF, dtype) buffer) (#const OCI_ATTR_DATA_TYPE)
      typeCode  <- (#peek FIELD_DEF, dtype) buffer
      fetchAttr ((#ptr FIELD_DEF, dsize) buffer) (#const OCI_ATTR_DATA_SIZE)
      byteSize  <- (#peek FIELD_DEF, dsize) buffer
      fetchAttr ((#ptr FIELD_DEF, dprec) buffer) (#const OCI_ATTR_PRECISION)
      precision <- (#peek FIELD_DEF, dprec) buffer
      fetchAttr ((#ptr FIELD_DEF, dscale) buffer) (#const OCI_ATTR_SCALE)
      scale     <- (#peek FIELD_DEF, dscale) buffer
      fetchAttr ((#ptr FIELD_DEF, isNull) buffer) (#const OCI_ATTR_IS_NULL)
      nullable  <- (#peek FIELD_DEF, isNull) buffer
      -- the name comes back as a (pointer, ub4 length) pair, so OCI also
      -- gets the address of the length slot here
      ociAttrGet param (#const OCI_DTYPE_PARAM) ((#ptr FIELD_DEF, colName) buffer) ((#ptr FIELD_DEF, colNameLen) buffer) (#const OCI_ATTR_NAME) err >>= handleSqlResult err
      namePtr <- (#peek FIELD_DEF, colName) buffer
      nameLen <- (#peek FIELD_DEF, colNameLen) buffer
      name    <- peekCStringLen (namePtr, fromIntegral (nameLen :: (#type ub4)))
      ociDescriptorFree param (#const OCI_DTYPE_PARAM)
      return (name, mkSqlType typeCode byteSize precision scale, (nullable :: (#type ub1)) /= 0)
            
    -- Transaction control on the connection's service context.  The first
    -- argument (the environment reference passed in by connect) is unused.

    -- Start a read/write transaction (OCITransStart, timeout 0).
    beginTransaction _ svcCtx err =
      handleSqlResult err =<< ociTransStart svcCtx err 0 (#const OCI_TRANS_READWRITE)

    -- Commit the current transaction.
    commitTransaction _ svcCtx err =
      handleSqlResult err =<< ociTransCommit svcCtx err (#const OCI_DEFAULT)

    -- Roll back the current transaction.
    rollbackTransaction _ svcCtx err =
      handleSqlResult err =<< ociTransRollback svcCtx err (#const OCI_DEFAULT)

    -- Release a statement created by query.  First the OCI statement handle
    -- is freed, then the malloc'ed column buffer.  The buffer is freed even
    -- when releasing the handle reports an error.
    closeStatement stmt buffer err =
      (ociHandleFree stmt (#const OCI_HTYPE_STMT) >>= handleSqlResult err)
        `finally` free buffer

    -- Fetch the next row (one row, OCI_FETCH_NEXT) into the bound buffer.
    -- Returns False when OCI reports OCI_NO_DATA, meaning there are no more
    -- rows, and True otherwise.  Errors are thrown by handleSqlResult.
    fetch stmt err =
      ociStmtFetch2 stmt err 1 (#const OCI_FETCH_NEXT) 0 (#const OCI_DEFAULT) >>= \rc ->
        handleSqlResult err rc >> return (rc /= (#const OCI_NO_DATA))

    -- Pass column @colNumber@ of the current row to @f@.
    -- The column's byte offset is entry @colNumber@ of the CInt header at the
    -- start of @buffer@.  The value stored at that offset is treated as a
    -- NUL-terminated string, and its length is measured with strlen.
    getColValue :: Ptr () -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
    getColValue buffer colNumber fieldDef f = do
      offset <- peekElemOff (castPtr buffer :: Ptr CInt) colNumber
      let cell = buffer `plusPtr` fromIntegral offset :: CString
      len <- strlen cell
      f fieldDef cell (fromIntegral len)

Theme by Vikram Singh | Powered by WebSVN v2.3.3