Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/AST.hs
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,7 @@ data PrimFun sig where

-- local array operators
PrimVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, i) -> a)
PrimVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> PrimFun ((Vec n a, (i, a)) -> Vec n a)

-- general conversion between types
PrimFromIntegral :: IntegralType a -> NumType b -> PrimFun (a -> b)
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Analysis/Hash.hs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ encodePrimFun (PrimNEq a) = intHost $(hashQ "PrimNEq")
encodePrimFun (PrimMax a) = intHost $(hashQ "PrimMax") <> encodeSingleType a
encodePrimFun (PrimMin a) = intHost $(hashQ "PrimMin") <> encodeSingleType a
encodePrimFun (PrimVectorIndex (VectorType _ a) b) = intHost $(hashQ "PrimVectorIndex") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimVectorWrite (VectorType _ a) b) = intHost $(hashQ "PrimVectorWrite") <> encodeSingleType a <> encodeNumType (IntegralNumType b)
encodePrimFun (PrimFromIntegral a b) = intHost $(hashQ "PrimFromIntegral") <> encodeIntegralType a <> encodeNumType b
encodePrimFun (PrimToFloating a b) = intHost $(hashQ "PrimToFloating") <> encodeNumType a <> encodeFloatingType b
encodePrimFun PrimLAnd = intHost $(hashQ "PrimLAnd")
Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Classes/Vector.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import Data.Primitive.Vec
instance (VecElt a, KnownNat n) => Vectoring (Exp (Vec n a)) (Exp a) where
type IndexType (Exp (Vec n a)) = Exp Int
vecIndex = mkVectorIndex
vecWrite = mkVectorWrite
vecEmpty = undef


4 changes: 4 additions & 0 deletions src/Data/Array/Accelerate/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,7 @@ evalPrim PrimLAnd = evalLAnd
evalPrim PrimLOr = evalLOr
evalPrim PrimLNot = evalLNot
evalPrim (PrimVectorIndex v i) = evalVectorIndex v i
evalPrim (PrimVectorWrite v i) = evalVectorWrite v i
evalPrim (PrimFromIntegral ta tb) = evalFromIntegral ta tb
evalPrim (PrimToFloating ta tb) = evalToFloating ta tb

Expand Down Expand Up @@ -1174,6 +1175,9 @@ evalLNot = fromBool . not . toBool
evalVectorIndex :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, i) -> a
evalVectorIndex (VectorType n _) ti (v, i) | IntegralDict <- integralDict ti = vecIndex v (fromIntegral i)

evalVectorWrite :: (KnownNat n, Prim a) => VectorType (Vec n a) -> IntegralType i -> (Vec n a, (i, a)) -> Vec n a
evalVectorWrite (VectorType n _) ti (v, (i, a)) | IntegralDict <- integralDict ti = vecWrite v (fromIntegral i) a

evalFromIntegral :: IntegralType a -> NumType b -> a -> b
evalFromIntegral ta (IntegralNumType tb)
| IntegralDict <- integralDict ta
Expand Down
9 changes: 9 additions & 0 deletions src/Data/Array/Accelerate/Smart.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ module Data.Array.Accelerate.Smart (
-- ** Smart constructors for vector operations
mkVectorCreate,
mkVectorIndex,
mkVectorWrite,

-- ** Auxiliary functions
($$), ($$$), ($$$$), ($$$$$),
Expand Down Expand Up @@ -1190,6 +1191,11 @@ mkVectorIndex = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimBinary $ PrimVectorIndex @n (VectorType n singleType) integralType

mkVectorWrite :: forall n a. (KnownNat n, VecElt a) => Exp (Vec n a) -> Exp Int -> Exp a -> Exp (Vec n a)
mkVectorWrite = let n :: Int
n = fromIntegral $ natVal $ Proxy @n
in mkPrimTernary $ PrimVectorWrite @n (VectorType n singleType) integralType

-- Numeric conversions

mkFromIntegral :: (Elt a, Elt b, IsIntegral (EltR a), IsNum (EltR b)) => Exp a -> Exp b
Expand Down Expand Up @@ -1277,6 +1283,9 @@ mkPrimUnary prim (Exp a) = mkExp $ PrimApp prim a
mkPrimBinary :: (Elt a, Elt b, Elt c) => PrimFun ((EltR a, EltR b) -> EltR c) -> Exp a -> Exp b -> Exp c
mkPrimBinary prim (Exp a) (Exp b) = mkExp $ PrimApp prim (SmartExp $ Pair a b)

mkPrimTernary :: (Elt a, Elt b, Elt c, Elt d) => PrimFun ((EltR a, (EltR b, EltR c)) -> EltR d) -> Exp a -> Exp b -> Exp c -> Exp d
mkPrimTernary prim (Exp a) (Exp b) (Exp c) = mkExp $ PrimApp prim (SmartExp $ Pair a (SmartExp (Pair b c)))

mkPrimUnaryBool :: Elt a => PrimFun (EltR a -> PrimBool) -> Exp a -> Exp Bool
mkPrimUnaryBool = mkCoerce @PrimBool $$ mkPrimUnary

Expand Down
1 change: 1 addition & 0 deletions src/Data/Array/Accelerate/Trafo/Algebra.hs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ evalPrimApp env f x
PrimMax ty -> evalMax ty x env
PrimMin ty -> evalMin ty x env
PrimVectorIndex _ _ -> Nothing
PrimVectorWrite _ _ -> Nothing
PrimLAnd -> evalLAnd x env
PrimLOr -> evalLOr x env
PrimLNot -> evalLNot x env
Expand Down
10 changes: 10 additions & 0 deletions src/Data/Primitive/Vec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module : Data.Primitive.Vec
Expand Down Expand Up @@ -96,6 +97,7 @@ data Vec (n :: Nat) a = Vec ByteArray#
class Vectoring vector a | vector -> a where
type IndexType vector :: Data.Kind.Type
vecIndex :: vector -> IndexType vector -> a
vecWrite :: vector -> IndexType vector -> a -> vector
vecEmpty :: vector

instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where
Expand All @@ -104,6 +106,14 @@ instance (KnownNat n, Prim a) => Vectoring (Vec n a) a where
n :: Int
n = fromIntegral $ natVal $ Proxy @n
in if i >= 0 && i < n then indexByteArray# ba# iu# else error ("index " <> show i <> " out of range in Vec of size " <> show n)
vecWrite vec@(Vec ba#) i@(I# iu#) v = runST $ do
let n :: Int
n = fromIntegral $ natVal $ Proxy @n
mba <- newByteArray (n * sizeOf (undefined :: a))
let new_vs = zipWith (\i' v' -> if i' == i then v else v') [0..n] (listOfVec vec)
zipWithM_ (writeByteArray mba) [0..n] new_vs
ByteArray nba# <- unsafeFreezeByteArray mba
return $! Vec nba#
vecEmpty = mkVec


Expand Down