diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index 48f1468c1316..a1de1be84099 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -2086,31 +2086,21 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) { debug(3) << "\n"; if (arg_ids.size() == 1) { - // 1 argument, just do a simple assignment via a cast SpvId result_id = cast_type(op->type, op->vectors[0].type(), arg_ids[0]); builder.update_id(result_id); } else if (arg_ids.size() == 2) { - - // 2 arguments, use a composite insert to update even and odd indices - uint32_t even_idx = 0; - uint32_t odd_idx = 1; - SpvFactory::Indices even_indices; - SpvFactory::Indices odd_indices; - for (int i = 0; i < op_lanes; ++i) { - even_indices.push_back(even_idx); - odd_indices.push_back(odd_idx); - even_idx += 2; - odd_idx += 2; + // 2 arguments, use vector-shuffle with logical indices indexing into (vec1[0], vec1[1], ..., vec2[0], vec2[1], ...) + SpvFactory::Indices logical_indices; + for (int i = 0; i < arg_lanes; ++i) { + logical_indices.push_back(uint32_t(i)); + logical_indices.push_back(uint32_t(i + arg_lanes)); } SpvId type_id = builder.declare_type(op->type); - SpvId value_id = builder.declare_null_constant(op->type); - SpvId partial_id = builder.reserve_id(SpvResultId); SpvId result_id = builder.reserve_id(SpvResultId); - builder.append(SpvFactory::composite_insert(type_id, partial_id, arg_ids[0], value_id, even_indices)); - builder.append(SpvFactory::composite_insert(type_id, result_id, arg_ids[1], partial_id, odd_indices)); + builder.append(SpvFactory::vector_shuffle(type_id, result_id, arg_ids[0], arg_ids[1], logical_indices)); builder.update_id(result_id); } else { @@ -2140,7 +2130,7 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Shuffle *op) { } else if (op->is_extract_element()) { int idx = op->indices[0]; internal_assert(idx >= 0); - internal_assert(idx <= op->vectors[0].type().lanes()); + internal_assert(idx < op->vectors[0].type().lanes()); if (op->vectors[0].type().is_vector()) { SpvFactory::Indices indices = {(uint32_t)idx}; SpvId type_id = builder.declare_type(op->type);