?prevdifflink? - Blame
-----------------------------------------------------------------------------------------
{-| Module : Database.HSQL.Oracle
Copyright : (c) Krasimir Angelov 2003
License : BSD-style
Maintainer : kr.angelov@gmail.com
Stability : provisional
Portability : portable
This module provides an interface to the Oracle database via the Oracle Call Interface (OCI).
-}
-----------------------------------------------------------------------------------------
module Database.HSQL.Oracle(connect, module Database.HSQL) where
import Database.HSQL
import Database.HSQL.Types
import Foreign
import Foreign.C
import Foreign.Concurrent as FC
import Control.Concurrent.MVar
import Control.Exception(throwDyn, finally)
import Data.Word
#include <HsOCI.h>
-- Every OCI handle is represented as an untyped pointer. The aliases below
-- are plain synonyms of 'OCIHandle'; they make the foreign import signatures
-- easier to read but give no protection against passing one kind of handle
-- where another is expected.
type OCIHandle = Ptr ()
type OCIEnv = OCIHandle
type OCIError = OCIHandle
type OCISvcCtx = OCIHandle
type OCIStmt = OCIHandle
type OCIParam = OCIHandle
type OCIDefine = OCIHandle
type OCIDescribe=OCIHandle
-- Environment handle wrapped in a 'ForeignPtr' so that a finalizer can be
-- attached to it (see 'myEnvironment').
type OCIEnvRef = ForeignPtr ()
-- Raw bindings to the OCI client library. Every OCI function returns a status
-- code (sword) that is meant to be passed to 'handleSqlResult'.
--
-- Parameters that are size_t on the C side (the extra-memory sizes of
-- OCIEnvCreate and OCIHandleAlloc, and the result of strlen) are declared as
-- CSize, and the uword timeout of OCITransStart as CUInt, so that the Haskell
-- types match the C prototypes on 64-bit platforms as well.

-- Environment and handle management
foreign import ccall "OCIEnvCreate" ociEnvCreate :: Ptr OCIEnv -> CInt -> Ptr a -> FunPtr a -> FunPtr a -> FunPtr a -> CSize -> Ptr (Ptr a) -> IO CInt
foreign import ccall "OCITerminate" ociTerminate :: CInt -> IO CInt
foreign import ccall "OCIHandleAlloc" ociHandleAlloc :: OCIHandle -> Ptr OCIHandle -> CInt -> CSize -> Ptr a -> IO CInt
foreign import ccall "OCIHandleFree" ociHandleFree :: OCIHandle -> CInt -> IO CInt
-- Diagnostics
foreign import ccall "OCIErrorGet" ociErrorGet :: OCIHandle -> CInt -> CString -> Ptr CInt -> CString -> CInt -> CInt -> IO CInt
-- Session
foreign import ccall "OCILogon" ociLogon :: OCIEnv -> OCIError -> Ptr OCISvcCtx -> CString -> CInt -> CString -> CInt -> CString -> CInt -> IO CInt
foreign import ccall "OCILogoff" ociLogoff :: OCISvcCtx -> OCIError -> IO CInt
-- Statements
foreign import ccall "OCIStmtPrepare" ociStmtPrepare :: OCIStmt -> OCIError -> CString -> CInt -> CInt -> CInt -> IO CInt
foreign import ccall "OCIStmtExecute" ociStmtExecute :: OCISvcCtx -> OCIStmt -> OCIError -> CInt -> CInt -> OCIHandle -> OCIHandle -> CInt -> IO CInt
foreign import ccall "OCIStmtFetch2" ociStmtFetch2 :: OCIStmt -> OCIError -> CInt -> CInt -> CInt -> CInt -> IO CInt
foreign import ccall "OCIDefineByPos" ociDefineByPos :: OCIStmt -> Ptr OCIDefine -> OCIError -> CInt -> Ptr a -> CInt -> CShort -> Ptr CShort -> Ptr CShort -> Ptr CShort -> CInt -> IO CInt
-- Metadata
foreign import ccall "OCIParamGet" ociParamGet :: OCIHandle -> CInt -> OCIError -> Ptr OCIParam -> CInt -> IO CInt
foreign import ccall "OCIAttrGet" ociAttrGet :: OCIParam -> CInt -> Ptr a -> Ptr CInt -> CInt -> OCIError -> IO CInt
foreign import ccall "OCIDescribeAny" ociDescribeAny :: OCISvcCtx -> OCIError -> Ptr () -> CInt -> Word8 -> Word8 -> Word8 -> OCIDescribe -> IO CInt
foreign import ccall "OCIDescriptorFree" ociDescriptorFree :: OCIHandle -> CInt -> IO CInt
-- Transactions
foreign import ccall "OCITransStart" ociTransStart :: OCISvcCtx -> OCIError -> CUInt -> CInt -> IO CInt
foreign import ccall "OCITransCommit" ociTransCommit :: OCISvcCtx -> OCIError -> CInt -> IO CInt
foreign import ccall "OCITransRollback" ociTransRollback :: OCISvcCtx -> OCIError -> CInt -> IO CInt
-- C library helper used to measure the NUL-terminated column values
foreign import ccall "strlen" strlen :: CString -> IO CSize
-----------------------------------------------------------------------------------------
-- keeper of OCIEnv
-----------------------------------------------------------------------------------------
{-# NOINLINE myEnvironment #-}
-- | The process-wide OCI environment. It is created the first time it is
-- demanded (hence 'unsafePerformIO' together with NOINLINE, so that only one
-- environment is ever created) and carries a finalizer that calls
-- OCITerminate.
myEnvironment :: OCIEnvRef
myEnvironment = unsafePerformIO createEnvironment
  where
    -- Ask OCI for a fresh environment handle and wrap it in a ForeignPtr.
    -- No error handle exists yet, so failures are reported against nullPtr.
    createEnvironment =
      alloca $ \envSlot -> do
        status <- ociEnvCreate envSlot (#const OCI_DEFAULT) nullPtr nullFunPtr nullFunPtr nullFunPtr 0 nullPtr
        handleSqlResult nullPtr status
        envHandle <- peek envSlot
        FC.newForeignPtr envHandle shutdownOCI

    -- Finalizer attached to the environment handle.
    shutdownOCI = do
      status <- ociTerminate (#const OCI_DEFAULT)
      handleSqlResult nullPtr status
-----------------------------------------------------------------------------------------
-- error handling
-----------------------------------------------------------------------------------------
-- | Interprets the status code returned by an OCI call.
--
-- * @OCI_SUCCESS@ and @OCI_NO_DATA@ are treated as success.
-- * @OCI_SUCCESS_WITH_INFO@ is treated as success; when compiled with DEBUG
--   the attached diagnostic record is traced.
-- * @OCI_INVALID_HANDLE@, @OCI_STILL_EXECUTING@ and @OCI_NEED_DATA@ throw the
--   corresponding 'SqlError' constructor.
-- * @OCI_ERROR@ throws the diagnostic record read from the error handle.
-- * Any other code aborts with 'error'.
--
-- The first argument is the error handle the diagnostic record is read from;
-- callers pass 'nullPtr' when no error handle exists yet.
handleSqlResult :: OCIHandle -> CInt -> IO ()
handleSqlResult err res
  | res == (#const OCI_SUCCESS) || res == (#const OCI_NO_DATA) = return ()
  | res == (#const OCI_SUCCESS_WITH_INFO) = do
#ifdef DEBUG
      e <- getSqlError
      putTraceMsg (show e)
#else
      return ()
#endif
  | res == (#const OCI_INVALID_HANDLE) = throwDyn SqlInvalidHandle
  | res == (#const OCI_STILL_EXECUTING) = throwDyn SqlStillExecuting
  | res == (#const OCI_NEED_DATA) = throwDyn SqlNeedData
  | res == (#const OCI_ERROR) = getSqlError >>= throwDyn
  | otherwise = error (show res)
  where
    stringBufferLen = 1024

    -- Reads the first diagnostic record from the error handle.
    getSqlError =
      alloca $ \pErrCode ->
      allocaBytes stringBufferLen $ \pErrMsg -> do
        -- Start from an empty C string so that the buffer is always
        -- NUL-terminated, even if OCI does not write a message into it.
        poke pErrMsg 0
        rc <- ociErrorGet err 1 nullPtr pErrCode pErrMsg (fromIntegral stringBufferLen) (#const OCI_HTYPE_ERROR)
        -- A negative code means the lookup itself failed; OCI_NO_DATA (a
        -- positive code) means there is no record, so neither the error code
        -- nor the message buffer has been filled in.
        if rc < 0 || rc == (#const OCI_NO_DATA)
          then return SqlNoData
          else do
            msg <- peekCString pErrMsg
            errCode <- peek pErrCode
            return (SqlError {seState="", seNativeError=fromIntegral errCode, seErrorMsg=msg})
-- | Opens a new session on an Oracle service and returns it as a
-- 'Connection'. An OCI error handle is allocated for the session and shared
-- by every operation of the returned connection.
connect :: String -- ^ Service name
        -> String -- ^ User identifier
        -> String -- ^ Password
        -> IO Connection -- ^ the returned value represents the new connection
connect service user pwd =
  withForeignPtr myEnvironment $ \env ->
  withCStringLen user $ \(cUser, cUserLen) ->
  withCStringLen pwd $ \(cPwd, cPwdLen) ->
  withCStringLen service $ \(cService, cServiceLen) ->
  alloca $ \errSlot ->
  alloca $ \svcSlot -> do
    -- The error handle does not exist yet, so its own allocation is checked
    -- against nullPtr.
    allocStatus <- ociHandleAlloc env errSlot (#const OCI_HTYPE_ERROR) 0 nullPtr
    handleSqlResult nullPtr allocStatus
    err <- peek errSlot
    ociLogon env err svcSlot
             cUser (fromIntegral cUserLen)
             cPwd (fromIntegral cPwdLen)
             cService (fromIntegral cServiceLen)
      >>= handleSqlResult err
    svcCtx <- peek svcSlot
    closedFlag <- newMVar False
    -- The record refers to itself: statements created by connQuery keep a
    -- reference to the connection they belong to.
    let connection = Connection
          { connDisconnect = disconnect svcCtx err
          , connExecute = execute myEnvironment svcCtx err
          , connQuery = query connection myEnvironment svcCtx err
          , connTables = tables env svcCtx err
          , connDescribe = describe env svcCtx err
          , connBeginTransaction = beginTransaction myEnvironment svcCtx err
          , connCommitTransaction = commitTransaction myEnvironment svcCtx err
          , connRollbackTransaction = rollbackTransaction myEnvironment svcCtx err
          , connClosed = closedFlag
          }
    return connection
where
-- Logs the session off and then releases the error handle that was
-- allocated for it by 'connect'.
disconnect svcCtx err = do
  logoffStatus <- ociLogoff svcCtx err
  handleSqlResult err logoffStatus
  freeStatus <- ociHandleFree err (#const OCI_HTYPE_ERROR)
  handleSqlResult err freeStatus
-- Runs a statement that returns no result set (one iteration is requested
-- from OCIStmtExecute).
--
-- The statement handle is released with 'finally', so it is freed even when
-- preparing or executing the statement throws; previously the handle leaked
-- on every failed statement.
execute envRef svcCtx err query =
  withForeignPtr envRef $ \env ->
  withCStringLen query $ \(sqlPtr, sqlLen) ->
  alloca $ \pStmt -> do
    ociHandleAlloc env pStmt (#const OCI_HTYPE_STMT) 0 nullPtr >>= handleSqlResult err
    stmt <- peek pStmt
    (runStatement sqlPtr sqlLen stmt) `finally` (releaseStatement stmt)
  where
    -- Prepare and execute the statement text on the given handle.
    runStatement sqlPtr sqlLen stmt = do
      ociStmtPrepare stmt err sqlPtr (fromIntegral sqlLen) (#const OCI_NTV_SYNTAX) (#const OCI_DEFAULT) >>= handleSqlResult err
      ociStmtExecute svcCtx stmt err 1 0 nullPtr nullPtr (#const OCI_DEFAULT) >>= handleSqlResult err

    -- Give the statement handle back to OCI.
    releaseStatement stmt =
      ociHandleFree stmt (#const OCI_HTYPE_STMT) >>= handleSqlResult err
-- Prepares and executes a statement that returns rows, and builds the
-- 'Statement' used to fetch them.
--
-- All column values are fetched as NUL-terminated strings (SQLT_STR) into a
-- single malloc'ed buffer laid out as
--
-- >  [ one CInt offset per column | column 1 value | column 2 value | ... ]
--
-- where each offset is the byte position of that column's value inside the
-- buffer. 'getColValue' reads values back through this table and
-- 'closeStatement' frees the buffer.
query connection envRef svcCtx err query =
  withForeignPtr envRef $ \env ->
  withCStringLen query $ \(sqlPtr, sqlLen) ->
  alloca $ \pStmt -> do
    ociHandleAlloc env pStmt (#const OCI_HTYPE_STMT) 0 nullPtr >>= handleSqlResult err
    stmt <- peek pStmt
    ociStmtPrepare stmt err sqlPtr (fromIntegral sqlLen) (#const OCI_NTV_SYNTAX) (#const OCI_DEFAULT) >>= handleSqlResult err
    -- Zero iterations: execute only, rows are fetched later
    ociStmtExecute svcCtx stmt err 0 0 nullPtr nullPtr (#const OCI_DEFAULT) >>= handleSqlResult err
    fields <- allocaBytes (#const (sizeof(FIELD_DEF))) (getFieldDefs stmt 1)
    -- Size in bytes of the offset table; the self-reference only fixes the
    -- element type for sizeOf, the value is never inspected.
    let offsets_arr_size = fromIntegral (length fields * sizeOf offsets_arr_size) :: CInt
    buffer <- mallocBytes (fromIntegral (foldr ((+) . sqlType2Size) offsets_arr_size fields))
    definePositions stmt err buffer 0 offsets_arr_size fields
    refFalse <- newMVar False
    let statement = Statement
          { stmtConn = connection
          , stmtClose = closeStatement stmt buffer err
          , stmtFetch = fetch stmt err
          , stmtGetCol = getColValue buffer
          , stmtFields = fields
          , stmtClosed = refFalse
          }
    return statement
  where
    -- Reads the select-list descriptors one position at a time; the first
    -- position for which OCIParamGet does not succeed ends the list.
    getFieldDefs stmt counter buffer = do
      res <- ociParamGet stmt (#const OCI_HTYPE_STMT) err ((#ptr FIELD_DEF, par) buffer) counter
      if res == (#const OCI_SUCCESS)
        then do field <- getFieldDef err buffer
                fields <- getFieldDefs stmt (counter+1) buffer
                return (field:fields)
        else return []

    -- Number of bytes reserved in the row buffer for a column of the given
    -- type, including the terminating '\0'.
    sqlType2Size :: FieldDef -> CInt
    sqlType2Size (_,tp,_) =
      case tp of
        SqlVarChar n   -> fromIntegral n+1
        -- CHAR columns (SQLT_AFC) are reported as SqlChar by mkSqlType; they
        -- previously had no case here, so any query returning a CHAR column
        -- failed with a pattern-match error.
        SqlChar n      -> fromIntegral n+1
        SqlNumeric p s -> fromIntegral (p+s+3) -- The value precision plus optional positions for '.', '-' and
                                               -- one position for the '\0' character at end of the string.
        SqlInteger     -> 16                   -- 12 digits are enough (maxBound :: Int) has 10 digits.
                                               -- in addition we may need one position for '-' and one
                                               -- for the '\0' character at end of the string.
        SqlFloat       -> 100
        SqlDate        -> 100
        SqlTime        -> 100
        SqlTimeTZ      -> 100
        SqlTimeStamp   -> 100
        SqlText        -> 100
        SqlUnknown _   -> 0

    -- Records each column's offset in the offset table and binds the
    -- column to its slot in the row buffer.
    -- NOTE(review): the status returned by OCIDefineByPos is not checked.
    definePositions stmt err buffer pos offset [] = return ()
    definePositions stmt err buffer pos offset (field:fields) =
      alloca $ \pDef -> do
        let size = sqlType2Size field
        poke (castPtr buffer `advancePtr` fromIntegral pos) offset
        ociDefineByPos stmt pDef err (pos+1) (buffer `plusPtr` fromIntegral offset) size (#const SQLT_STR) nullPtr nullPtr nullPtr (#const OCI_DEFAULT)
        definePositions stmt err buffer (pos+1) (offset+size) fields
-- Maps an OCI external type code, together with the reported size,
-- precision and scale, to the corresponding 'SqlType'. Codes without a
-- dedicated mapping are returned as 'SqlUnknown'.
mkSqlType :: (#type OCITypeCode) -> (#type ub2) -> (#type ub1) -> (#type ub1) -> SqlType
mkSqlType typeCode size prec scale =
  case typeCode of
    (#const SQLT_CHR)       -> SqlVarChar (fromIntegral size)
    (#const SQLT_AFC)       -> SqlChar (fromIntegral size)
    (#const SQLT_NUM)       -> SqlNumeric (fromIntegral prec) (fromIntegral scale)
    (#const SQLT_INT)       -> SqlInteger
    (#const SQLT_FLT)       -> SqlFloat
    (#const SQLT_DATE)      -> SqlDate
    (#const SQLT_TIME)      -> SqlTime
    (#const SQLT_TIME_TZ)   -> SqlTimeTZ
    (#const SQLT_TIMESTAMP) -> SqlTimeStamp
    (#const SQLT_LNG)       -> SqlText
    other                   -> SqlUnknown (fromIntegral other)
-- Lists the names of the tables found in the schema "COREDB_SYSTEM".
-- The schema is described with OCIDescribeAny, its object list is walked
-- entry by entry, and only entries whose parameter type is OCI_PTYPE_TABLE
-- are kept.
-- NOTE(review): the schema name is hard-coded here.
tables env svcCtx err =
  withCStringLen "COREDB_SYSTEM" $ \(schemaPtr, schemaLen) ->
  alloca $ \descrSlot ->
  alloca $ \paramSlot ->
  alloca $ \listSlot -> do
    allocStatus <- ociHandleAlloc env descrSlot (#const OCI_HTYPE_DESCRIBE) 0 nullPtr
    handleSqlResult err allocStatus
    descr <- peek descrSlot
    describeStatus <- ociDescribeAny svcCtx err (castPtr schemaPtr) (fromIntegral schemaLen) (#const OCI_OTYPE_NAME) (#const OCI_DEFAULT) (#const OCI_PTYPE_SCHEMA) descr
    handleSqlResult err describeStatus
    -- describe handle -> schema parameter -> list of schema objects
    paramStatus <- ociAttrGet descr (#const OCI_HTYPE_DESCRIBE) paramSlot nullPtr (#const OCI_ATTR_PARAM) err
    handleSqlResult err paramStatus
    schemaParam <- peek paramSlot
    listStatus <- ociAttrGet schemaParam (#const OCI_DTYPE_PARAM) listSlot nullPtr (#const OCI_ATTR_LIST_OBJECTS) err
    handleSqlResult err listStatus
    objects <- peek listSlot
    tableNames <- allocaBytes (#const (sizeof(FIELD_DEF))) (collectTableNames objects 1)
    -- The status codes of the two descriptor frees are not checked.
    ociDescriptorFree objects (#const OCI_DTYPE_PARAM)
    ociDescriptorFree schemaParam (#const OCI_DTYPE_PARAM)
    freeStatus <- ociHandleFree descr (#const OCI_HTYPE_DESCRIBE)
    handleSqlResult err freeStatus
    return tableNames
  where
    -- Walks the object list starting at the given position; the first
    -- position for which OCIParamGet does not succeed ends the walk.
    -- The scratch FIELD_DEF record receives the attribute values.
    collectTableNames objects position scratch = do
      status <- ociParamGet objects (#const OCI_DTYPE_PARAM) err ((#ptr FIELD_DEF, par) scratch) position
      entry <- (#peek FIELD_DEF, par) scratch
      if status /= (#const OCI_SUCCESS)
        then return []
        else do
          nameStatus <- ociAttrGet entry (#const OCI_DTYPE_PARAM) ((#ptr FIELD_DEF, colName) scratch) ((#ptr FIELD_DEF, colNameLen) scratch) (#const OCI_ATTR_OBJ_NAME) err
          handleSqlResult err nameStatus
          namePtr <- (#peek FIELD_DEF, colName) scratch
          nameLen <- (#peek FIELD_DEF, colNameLen) scratch
          name <- peekCStringLen (namePtr, fromIntegral (nameLen :: (#type ub4)))
          typeStatus <- ociAttrGet entry (#const OCI_DTYPE_PARAM) ((#ptr FIELD_DEF, dtype) scratch) nullPtr (#const OCI_ATTR_PTYPE) err
          handleSqlResult err typeStatus
          ptype <- (#peek FIELD_DEF, dtype) scratch
          ociDescriptorFree entry (#const OCI_DTYPE_PARAM)
          rest <- collectTableNames objects (position+1) scratch
          let isTable = ptype == ((#const OCI_PTYPE_TABLE) :: (#type ub1))
          return $! (if isTable then name:rest else rest)
-- Returns the field definitions of the columns of the named table.
-- The table is described with OCIDescribeAny, the column count is read from
-- the table parameter, and each entry of the column list is converted with
-- 'getFieldDef'.
describe env svcCtx err tblName =
  withCStringLen tblName $ \(namePtr, nameLen) ->
  alloca $ \descrSlot ->
  alloca $ \paramSlot ->
  alloca $ \listSlot ->
  alloca $ \countSlot -> do
    allocStatus <- ociHandleAlloc env descrSlot (#const OCI_HTYPE_DESCRIBE) 0 nullPtr
    handleSqlResult err allocStatus
    descr <- peek descrSlot
    describeStatus <- ociDescribeAny svcCtx err (castPtr namePtr) (fromIntegral nameLen) (#const OCI_OTYPE_NAME) (#const OCI_DEFAULT) (#const OCI_PTYPE_TABLE) descr
    handleSqlResult err describeStatus
    -- describe handle -> table parameter -> column count and column list
    paramStatus <- ociAttrGet descr (#const OCI_HTYPE_DESCRIBE) paramSlot nullPtr (#const OCI_ATTR_PARAM) err
    handleSqlResult err paramStatus
    tableParam <- peek paramSlot
    countStatus <- ociAttrGet tableParam (#const OCI_DTYPE_PARAM) countSlot nullPtr (#const OCI_ATTR_NUM_COLS) err
    handleSqlResult err countStatus
    columnCount <- peek (countSlot :: Ptr (#type ub2))
    listStatus <- ociAttrGet tableParam (#const OCI_DTYPE_PARAM) listSlot nullPtr (#const OCI_ATTR_LIST_COLUMNS) err
    handleSqlResult err listStatus
    columns <- peek listSlot
    defs <- allocaBytes (#const (sizeof(FIELD_DEF))) (collectFieldDefs columns 1 (fromIntegral columnCount))
    -- The status codes of the two descriptor frees are not checked.
    ociDescriptorFree columns (#const OCI_DTYPE_PARAM)
    ociDescriptorFree tableParam (#const OCI_DTYPE_PARAM)
    freeStatus <- ociHandleFree descr (#const OCI_HTYPE_DESCRIBE)
    handleSqlResult err freeStatus
    return defs
  where
    -- Converts the column list entries at positions position..lastColumn
    -- (1-based), in order. The scratch FIELD_DEF record is reused for
    -- every column.
    collectFieldDefs columns position lastColumn scratch =
      if position > lastColumn
        then return []
        else do
          status <- ociParamGet columns (#const OCI_DTYPE_PARAM) err ((#ptr FIELD_DEF, par) scratch) position
          handleSqlResult err status
          def <- getFieldDef err scratch
          rest <- collectFieldDefs columns (position+1) lastColumn scratch
          return (def:rest)
-- Builds a field definition (name, SQL type, nullability) from the column
-- parameter descriptor stored in the 'par' member of the FIELD_DEF record
-- the buffer points to. Each attribute is read into its own member of that
-- record and then peeked back. The descriptor is freed before returning.
getFieldDef err buffer = do
  column <- (#peek FIELD_DEF, par) buffer
  readAttr column ((#ptr FIELD_DEF, dtype) buffer) (#const OCI_ATTR_DATA_TYPE)
  dtype <- (#peek FIELD_DEF, dtype) buffer
  readAttr column ((#ptr FIELD_DEF, dsize) buffer) (#const OCI_ATTR_DATA_SIZE)
  dsize <- (#peek FIELD_DEF, dsize) buffer
  readAttr column ((#ptr FIELD_DEF, dprec) buffer) (#const OCI_ATTR_PRECISION)
  dprec <- (#peek FIELD_DEF, dprec) buffer
  readAttr column ((#ptr FIELD_DEF, dscale) buffer) (#const OCI_ATTR_SCALE)
  dscale <- (#peek FIELD_DEF, dscale) buffer
  readAttr column ((#ptr FIELD_DEF, isNull) buffer) (#const OCI_ATTR_IS_NULL)
  isNull <- (#peek FIELD_DEF, isNull) buffer
  -- The name is the only attribute that also reports a length.
  nameStatus <- ociAttrGet column (#const OCI_DTYPE_PARAM) ((#ptr FIELD_DEF, colName) buffer) ((#ptr FIELD_DEF, colNameLen) buffer) (#const OCI_ATTR_NAME) err
  handleSqlResult err nameStatus
  namePtr <- (#peek FIELD_DEF, colName) buffer
  nameLen <- (#peek FIELD_DEF, colNameLen) buffer
  colName <- peekCStringLen (namePtr, fromIntegral (nameLen :: (#type ub4)))
  ociDescriptorFree column (#const OCI_DTYPE_PARAM)
  let nullable = toBool (fromIntegral (isNull :: (#type ub1)))
  return (colName, mkSqlType dtype dsize dprec dscale, nullable)
  where
    -- Reads one attribute of the column descriptor into the given target,
    -- without asking for its size, and checks the status.
    readAttr :: OCIParam -> Ptr a -> CInt -> IO ()
    readAttr column target attr =
      ociAttrGet column (#const OCI_DTYPE_PARAM) target nullPtr attr err >>= handleSqlResult err
-- Starts a read-write transaction on the session, passing a timeout of 0.
-- The environment reference is accepted for uniformity with the other
-- connection operations but is not used.
beginTransaction _envRef svcCtx err = do
  status <- ociTransStart svcCtx err 0 (#const OCI_TRANS_READWRITE)
  handleSqlResult err status
-- Commits the current transaction of the session.
-- The environment reference is not used.
commitTransaction _envRef svcCtx err = do
  status <- ociTransCommit svcCtx err (#const OCI_DEFAULT)
  handleSqlResult err status
-- Rolls back the current transaction of the session.
-- The environment reference is not used.
rollbackTransaction _envRef svcCtx err =
  ociTransRollback svcCtx err (#const OCI_DEFAULT) >>= \status ->
    handleSqlResult err status
-- Releases the statement handle and the row buffer allocated by 'query'.
--
-- The buffer is freed before the status of OCIHandleFree is examined, so it
-- is released even when freeing the handle reports an error; previously the
-- exception was thrown first and the buffer leaked.
closeStatement stmt buffer err = do
  res <- ociHandleFree stmt (#const OCI_HTYPE_STMT)
  free buffer
  handleSqlResult err res
-- Fetches the next row of the statement into the buffers defined for it.
-- Returns False once OCI reports OCI_NO_DATA and True otherwise; any other
-- non-success status is raised by 'handleSqlResult'.
fetch stmt err =
  ociStmtFetch2 stmt err 1 (#const OCI_FETCH_NEXT) 0 (#const OCI_DEFAULT) >>= \status -> do
    handleSqlResult err status
    let exhausted = status == (#const OCI_NO_DATA)
    return (not exhausted)
-- Hands the raw value of one column of the current row to a conversion
-- function. The start of the buffer is read as an array of CInt offsets
-- indexed by column number; the value is the NUL-terminated string found
-- at that offset, and its length is measured with strlen.
getColValue :: Ptr () -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
getColValue buffer colNumber fieldDef f = do
  columnOffset <- peekElemOff offsetTable colNumber
  let cell = castPtr buffer `plusPtr` fromIntegral columnOffset
  cellLen <- strlen cell
  f fieldDef cell (fromIntegral cellLen)
  where
    -- The offset table occupies the start of the row buffer.
    offsetTable = castPtr buffer :: Ptr CInt
|