?prevdifflink? - Blame
-----------------------------------------------------------------------------------------
{-| Module : Database.HSQL.PostgreSQL
Copyright : (c) Krasimir Angelov 2003
License : BSD-style
Maintainer : ka2_mail@yahoo.com
Stability : provisional
Portability : portable
This module provides an interface to the PostgreSQL database server.
-}
-----------------------------------------------------------------------------------------
module Database.HSQL.PostgreSQL(connect, module Database.HSQL) where
import Database.HSQL
import Database.HSQL.Types
import Data.Dynamic
import Data.Char
import Foreign
import Foreign.C
import Control.Exception (throwDyn, catchDyn, dynExceptions, Exception(..))
import Control.Monad(when,unless,mplus)
import Control.Concurrent.MVar
import System.Time
import System.IO.Unsafe
import Text.ParserCombinators.ReadP
import Text.Read
import Numeric
# include <time.h>
#include <libpq-fe.h>
#include <postgres.h>
#include <catalog/pg_type.h>
-- | Opaque handle to a libpq connection (@PGconn *@).
type PGconn = Ptr ()
-- | Opaque handle to a libpq query result (@PGresult *@).
type PGresult = Ptr ()
-- C types from libpq-fe.h / postgres.h; hsc2hs picks the matching Haskell type.
type ConnStatusType = #type ConnStatusType
type ExecStatusType = #type ExecStatusType
type Oid = #type Oid

-- Bindings whose C signatures use only pointers and the types above can be
-- imported directly.
foreign import ccall "libpq-fe.h PQsetdbLogin" pqSetdbLogin :: CString -> CString -> CString -> CString -> CString -> CString -> CString -> IO PGconn
foreign import ccall "libpq-fe.h PQstatus" pqStatus :: PGconn -> IO ConnStatusType
foreign import ccall "libpq-fe.h PQerrorMessage" pqErrorMessage :: PGconn -> IO CString
foreign import ccall "libpq-fe.h PQfinish" pqFinish :: PGconn -> IO ()
foreign import ccall "libpq-fe.h PQexec" pqExec :: PGconn -> CString -> IO PGresult
foreign import ccall "libpq-fe.h PQresultStatus" pqResultStatus :: PGresult -> IO ExecStatusType
foreign import ccall "libpq-fe.h PQresStatus" pqResStatus :: ExecStatusType -> IO CString
foreign import ccall "libpq-fe.h PQresultErrorMessage" pqResultErrorMessage :: PGresult -> IO CString

-- The functions below take and/or return a C @int@ (or @size_t@ for strlen).
-- Importing them directly at Haskell 'Int' is an ABI mismatch wherever 'Int'
-- is wider than @int@: a negative result such as PQfmod's -1 ("no modifier")
-- need not come back as -1.  They are therefore imported at 'CInt' / 'CSize'
-- under @c_*@ names and wrapped so the rest of the module keeps using 'Int'.
foreign import ccall "libpq-fe.h PQnfields" c_PQnfields :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQntuples" c_PQntuples :: PGresult -> IO CInt
foreign import ccall "libpq-fe.h PQfname" c_PQfname :: PGresult -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQftype" c_PQftype :: PGresult -> CInt -> IO Oid
foreign import ccall "libpq-fe.h PQfmod" c_PQfmod :: PGresult -> CInt -> IO CInt
foreign import ccall "libpq-fe.h PQfnumber" c_PQfnumber :: PGresult -> CString -> IO CInt
foreign import ccall "libpq-fe.h PQgetvalue" c_PQgetvalue :: PGresult -> CInt -> CInt -> IO CString
foreign import ccall "libpq-fe.h PQgetisnull" c_PQgetisnull :: PGresult -> CInt -> CInt -> IO CInt
foreign import ccall "strlen" c_strlen :: CString -> IO CSize

-- | Number of columns in a result.
pgNFields :: PGresult -> IO Int
pgNFields res = fmap fromIntegral (c_PQnfields res)

-- | Number of rows in a result.
pqNTuples :: PGresult -> IO Int
pqNTuples res = fmap fromIntegral (c_PQntuples res)

-- | Name of the given column.
pgFName :: PGresult -> Int -> IO CString
pgFName res col = c_PQfname res (fromIntegral col)

-- | Type OID of the given column.
pqFType :: PGresult -> Int -> IO Oid
pqFType res col = c_PQftype res (fromIntegral col)

-- | Type modifier of the given column (libpq reports -1 when there is none).
pqFMod :: PGresult -> Int -> IO Int
pqFMod res col = fmap fromIntegral (c_PQfmod res (fromIntegral col))

-- | Column index for a column name (libpq reports -1 when not found).
pqFNumber :: PGresult -> CString -> IO Int
pqFNumber res name = fmap fromIntegral (c_PQfnumber res name)

-- | Text of the field at (row, column).
pqGetvalue :: PGresult -> Int -> Int -> IO CString
pqGetvalue res row col = c_PQgetvalue res (fromIntegral row) (fromIntegral col)

-- | Non-zero when the field at (row, column) is SQL NULL.
pqGetisnull :: PGresult -> Int -> Int -> IO Int
pqGetisnull res row col = fmap fromIntegral (c_PQgetisnull res (fromIntegral row) (fromIntegral col))

-- | Length of a NUL-terminated C string.
strlen :: CString -> IO Int
strlen s = fmap fromIntegral (c_strlen s)
-----------------------------------------------------------------------------------------
-- Connect/Disconnect
-----------------------------------------------------------------------------------------
-- | Releases a result object obtained from 'pqExec'. Each PGresult must be
-- released exactly once with PQclear, or its memory is leaked.
foreign import ccall "libpq-fe.h PQclear" pqClear :: PGresult -> IO ()

-- | Makes a new connection to the database server.
--
-- If the connection cannot be established, the libpq connection is finished
-- and a 'SqlError' with @seState = \"C\"@ is thrown via 'throwDyn'. Its
-- native error is the libpq connection status.
connect :: String -- ^ Server name
        -> String -- ^ Database name
        -> String -- ^ User identifier
        -> String -- ^ Authentication string (password)
        -> IO Connection
connect server database user authentication = do
    -- withCString frees every temporary C string once PQsetdbLogin returns.
    -- The previous hand-written frees forgot the database name and leaked it.
    -- (They also freed the other strings right after the call, so libpq does
    -- not keep these pointers.)
    pConn <- withCString server $ \pServer ->
             withCString database $ \pDatabase ->
             withCString user $ \pUser ->
             withCString authentication $ \pAuthentication ->
               pqSetdbLogin pServer nullPtr nullPtr nullPtr pDatabase pUser pAuthentication
    status <- pqStatus pConn
    unless (status == (#const CONNECTION_OK)) $ do
      -- Read the message before PQfinish releases the connection object.
      errMsg <- pqErrorMessage pConn >>= peekCString
      pqFinish pConn
      throwDyn (SqlError {seState="C", seNativeError=fromIntegral status, seErrorMsg=errMsg})
    refFalse <- newMVar False
    let connection = Connection
          { connDisconnect          = pqFinish pConn
          , connExecute             = execute pConn
          , connQuery               = query connection pConn
          , connTables              = tables connection pConn
          , connDescribe            = describe connection pConn
          , connBeginTransaction    = execute pConn "begin"
          , connCommitTransaction   = execute pConn "commit"
          , connRollbackTransaction = execute pConn "rollback"
          , connClosed              = refFalse
          }
    return connection
  where
    -- Runs one SQL statement and checks the result status.
    -- On success it returns the PGresult, and the caller becomes responsible
    -- for calling pqClear on it. On failure the result, if any, is cleared
    -- and a SqlError with seState "E" is thrown.
    execCommand :: PGconn -> String -> IO PGresult
    execCommand pConn sqlExpr = do
      pRes <- withCString sqlExpr (pqExec pConn)
      when (pRes == nullPtr) $ do
        -- No result object was produced, so the message lives on the connection.
        errMsg <- pqErrorMessage pConn >>= peekCString
        throwDyn (SqlError {seState="E", seNativeError=(#const PGRES_FATAL_ERROR), seErrorMsg=errMsg})
      status <- pqResultStatus pRes
      unless (status == (#const PGRES_COMMAND_OK) || status == (#const PGRES_TUPLES_OK)) $ do
        -- The message is owned by pRes, so copy it before clearing the result.
        errMsg <- pqResultErrorMessage pRes >>= peekCString
        pqClear pRes
        throwDyn (SqlError {seState="E", seNativeError=fromIntegral status, seErrorMsg=errMsg})
      return pRes

    -- Runs a statement whose result is not needed and releases it immediately.
    execute :: PGconn -> String -> IO ()
    execute pConn sqlExpr = execCommand pConn sqlExpr >>= pqClear

    -- Runs a query and wraps its result in a Statement. The PGresult stays
    -- alive until the statement is closed; stmtClose then releases it at most once.
    query :: Connection -> PGconn -> String -> IO Statement
    query conn pConn sqlText = do
        pRes <- execCommand pConn sqlText
        status <- pqResultStatus pRes
        -- Only row-returning results (PGRES_TUPLES_OK) carry column metadata.
        defs <- if status == (#const PGRES_TUPLES_OK)
                  then pgNFields pRes >>= getFieldDefs pRes 0
                  else return []
        countTuples <- pqNTuples pRes
        -- The cursor starts before the first row; see fetch.
        tupleIndex <- newMVar (-1)
        refFalse <- newMVar False
        -- Guards against calling PQclear twice if stmtClose is invoked again.
        cleared <- newMVar False
        return (Statement
          { stmtConn   = conn
          , stmtClose  = modifyMVar_ cleared (\done -> unless done (pqClear pRes) >> return True)
          , stmtFetch  = fetch tupleIndex countTuples
          , stmtGetCol = getColValue pRes tupleIndex countTuples
          , stmtFields = defs
          , stmtClosed = refFalse
          })
      where
        -- Collects (name, type, nullable) for columns i .. n-1.
        -- Every column is reported as nullable.
        getFieldDefs pRes i n
          | i >= n    = return []
          | otherwise = do
              name     <- pgFName pRes i >>= peekCString
              dataType <- pqFType pRes i
              modifier <- pqFMod pRes i
              rest     <- getFieldDefs pRes (i+1) n
              return ((name, mkSqlType dataType modifier, True) : rest)

    -- Maps a PostgreSQL type OID and type modifier to an HSQL SqlType.
    -- For char/varchar/numeric the modifier includes the 4-byte varlena header
    -- (VARHDRSZ), hence the "- 4". For numeric, the remaining value packs the
    -- precision in the high 16 bits and the scale in the low 16 bits.
    -- NOTE(review): the modifier is -1 when a column has no length/precision
    -- (e.g. plain varchar), which yields negative sizes below. Confirm how
    -- consumers of SqlType treat that.
    mkSqlType :: Oid -> Int -> SqlType
    mkSqlType (#const BPCHAROID)      size = SqlChar (size-4)
    mkSqlType (#const VARCHAROID)     size = SqlVarChar (size-4)
    mkSqlType (#const NAMEOID)        _    = SqlVarChar 31
    mkSqlType (#const TEXTOID)        _    = SqlText
    mkSqlType (#const NUMERICOID)     size = SqlNumeric ((size-4) `div` 0x10000) ((size-4) `mod` 0x10000)
    mkSqlType (#const INT2OID)        _    = SqlSmallInt
    mkSqlType (#const INT4OID)        _    = SqlInteger
    mkSqlType (#const FLOAT4OID)      _    = SqlReal
    mkSqlType (#const FLOAT8OID)      _    = SqlDouble
    mkSqlType (#const BOOLOID)        _    = SqlBit
    mkSqlType (#const BITOID)         size = SqlBinary size
    mkSqlType (#const VARBITOID)      size = SqlVarBinary size
    -- NOTE(review): bytea is binary data; SqlTinyInt looks like the wrong
    -- mapping. Confirm the intended SqlType before changing it.
    mkSqlType (#const BYTEAOID)       _    = SqlTinyInt
    mkSqlType (#const INT8OID)        _    = SqlBigInt
    mkSqlType (#const DATEOID)        _    = SqlDate
    mkSqlType (#const TIMEOID)        _    = SqlTime
    mkSqlType (#const TIMETZOID)      _    = SqlTimeTZ
    mkSqlType (#const ABSTIMEOID)     _    = SqlAbsTime
    mkSqlType (#const RELTIMEOID)     _    = SqlRelTime
    mkSqlType (#const INTERVALOID)    _    = SqlTimeInterval
    mkSqlType (#const TINTERVALOID)   _    = SqlAbsTimeInterval
    mkSqlType (#const TIMESTAMPOID)   _    = SqlDateTime
    mkSqlType (#const TIMESTAMPTZOID) _    = SqlDateTimeTZ
    mkSqlType (#const CASHOID)        _    = SqlMoney
    mkSqlType (#const INETOID)        _    = SqlINetAddr
    -- 829 is the macaddr type OID, spelled as a literal (original "hack"),
    -- presumably because the header in use lacked the macro.
    mkSqlType (#const 829)            _    = SqlMacAddr
    mkSqlType (#const CIDROID)        _    = SqlCIDRAddr
    mkSqlType (#const POINTOID)       _    = SqlPoint
    mkSqlType (#const LSEGOID)        _    = SqlLSeg
    mkSqlType (#const PATHOID)        _    = SqlPath
    mkSqlType (#const BOXOID)         _    = SqlBox
    mkSqlType (#const POLYGONOID)     _    = SqlPolygon
    mkSqlType (#const LINEOID)        _    = SqlLine
    mkSqlType (#const CIRCLEOID)      _    = SqlCircle
    mkSqlType tp                      _    = SqlUnknown (fromIntegral tp)

    -- Reads column colNumber of the current row, returning v when it is NULL.
    getFieldValue stmt colNumber fieldDef v = do
      mb_v <- stmtGetCol stmt colNumber fieldDef fromSqlCStringLen
      return (maybe v id mb_v)

    -- Lists names of ordinary tables (relkind 'r') not starting with "pg_".
    tables :: Connection -> PGconn -> IO [String]
    tables connection pConn = do
      stmt <- query connection pConn "select relname from pg_class where relkind='r' and relname !~ '^pg_'"
      collectRows (\s -> getFieldValue s 0 ("relname", SqlVarChar 0, False) "") stmt

    -- Describes the live (non-dropped, attnum > 0) columns of a table.
    -- The table name is embedded through toSqlValue.
    describe :: Connection -> PGconn -> String -> IO [FieldDef]
    describe connection pConn table = do
        stmt <- query connection pConn
                  ("select attname, atttypid, atttypmod, attnotnull " ++
                   "from pg_attribute as cols join pg_class as ts on cols.attrelid=ts.oid " ++
                   "where cols.attnum > 0 and ts.relname="++toSqlValue table++
                   " and cols.attisdropped = False ")
        collectRows getColumnInfo stmt
      where
        getColumnInfo stmt = do
          column_name <- getFieldValue stmt 0 ("attname", SqlVarChar 0, False) ""
          data_type   <- getFieldValue stmt 1 ("atttypid", SqlInteger, False) 0
          type_mod    <- getFieldValue stmt 2 ("atttypmod", SqlInteger, False) 0
          notnull     <- getFieldValue stmt 3 ("attnotnull", SqlBit, False) False
          let sqlType = mkSqlType (fromIntegral (data_type :: Int)) (fromIntegral (type_mod :: Int))
          return (column_name, sqlType, not notnull)

    -- Advances the row cursor (which starts at -1). Returns True while the
    -- new position is a valid row index.
    fetch :: MVar Int -> Int -> IO Bool
    fetch tupleIndex countTuples =
      modifyMVar tupleIndex (\index -> return (index+1, index < countTuples-1))

    -- Passes the current row's field text and length to f, or nullPtr and 0
    -- when the field is SQL NULL. Throws SqlNoData when the cursor is not on
    -- a valid row, i.e. before the first fetch or after the last.
    getColValue :: PGresult -> MVar Int -> Int -> Int -> FieldDef -> (FieldDef -> CString -> Int -> IO a) -> IO a
    getColValue pRes tupleIndex countTuples colNumber fieldDef f = do
      index <- readMVar tupleIndex
      when (index < 0 || index >= countTuples) (throwDyn SqlNoData)
      isnull <- pqGetisnull pRes index colNumber
      if isnull /= 0
        then f fieldDef nullPtr 0
        else do
          pStr <- pqGetvalue pRes index colNumber
          strLen <- strlen pStr
          f fieldDef pStr strLen
|