diff --git a/src/common.cpp b/src/common.cpp
index 10bc5d0707f..c2faf6834c5 100644
--- a/src/common.cpp
+++ b/src/common.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -195,6 +195,26 @@ std::vector<instruction_ref> insert_common_args(module& m,
     else
     {
         auto common = common_shape(to_shapes(inputs));
+        if(options.common_type)
+        {
+            // Exclude scalar literals from the common type decision. A scalar constant
+            // (e.g. an epsilon stored as float32) should adapt to the tensor's type rather
+            // than widening it. Without this, mixing a float32 scalar with a float16 tensor
+            // would promote the entire computation to float32, causing type mismatches in
+            // downstream ops (e.g. convolution, dot) whose weights remain float16.
+            std::vector<shape> tensor_shapes;
+            for(const auto& input : inputs)
+            {
+                if(not input->can_eval() or input->get_shape().elements() != 1)
+                    tensor_shapes.push_back(input->get_shape());
+            }
+            if(not tensor_shapes.empty())
+            {
+                auto tensor_type = compute_common_types(tensor_shapes);
+                if(tensor_type != common.type())
+                    common = shape{tensor_type, common.lens()};
+            }
+        }
         std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
             if(options.common_lens and input->get_shape().lens() != common.lens())
             {
diff --git a/src/onnx/parse_gru.cpp b/src/onnx/parse_gru.cpp
index baa473dde4c..fee27af6025 100644
--- a/src/onnx/parse_gru.cpp
+++ b/src/onnx/parse_gru.cpp
@@ -1,7 +1,7 @@
 /*
  * The MIT License (MIT)
  *
- * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
  *
  * Permission is hereby granted, free of charge, to any person obtaining a copy
  * of this software and associated documentation files (the "Software"), to deal
@@ -166,7 +166,12 @@ struct parse_gru : op_parser<parse_gru>
         {
             gru_transpose_inputs(info, args);
         }
-
+        if(not args[5]->is_undefined() and
+           args[5]->get_shape().type() != args[1]->get_shape().type())
+        {
+            args[5] = info.add_instruction(
+                make_op("convert", {{"target_type", args[1]->get_shape().type()}}), args[5]);
+        }
         // first output for concatenation of hidden states
         auto hidden_states = info.add_instruction(make_op("gru",
diff --git a/test/common_test.cpp b/test/common_test.cpp
new file mode 100644
index 00000000000..eb29cb4cdf6
--- /dev/null
+++ b/test/common_test.cpp
@@ -0,0 +1,85 @@
+/*
+ * The MIT License (MIT)
+ *
+ * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+ * THE SOFTWARE.
+ */
+#include "test.hpp"
+#include <migraphx/common.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/literal.hpp>
+#include <migraphx/make_op.hpp>
+#include <migraphx/module.hpp>
+
+TEST_CASE(add_common_op_scalar_literal_preserves_tensor_type)
+{
+    migraphx::module mm;
+
+    auto tensor = mm.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {1, 100, 128}});
+    auto scalar =
+        mm.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5f}});
+
+    auto result = migraphx::add_common_op(mm, migraphx::make_op("add"), {tensor, scalar});
+
+    EXPECT(result->get_shape().type() == migraphx::shape::half_type);
+}
+
+TEST_CASE(add_common_op_scalar_literal_preserves_tensor_type_reversed)
+{
+    migraphx::module mm;
+
+    auto tensor = mm.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {1, 100, 128}});
+    auto scalar =
+        mm.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5f}});
+
+    auto result = migraphx::add_common_op(mm, migraphx::make_op("add"), {scalar, tensor});
+
+    EXPECT(result->get_shape().type() == migraphx::shape::half_type);
+}
+
+TEST_CASE(add_common_op_two_tensors_promotes_to_wider_type)
+{
+    migraphx::module mm;
+
+    auto half_tensor = mm.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {1, 128}});
+    auto float_tensor =
+        mm.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {1, 128}});
+
+    auto result =
+        migraphx::add_common_op(mm, migraphx::make_op("add"), {half_tensor, float_tensor});
+
+    EXPECT(result->get_shape().type() == migraphx::shape::float_type);
+}
+
+TEST_CASE(add_common_op_float_tensor_with_float_scalar_keeps_float)
+{
+    migraphx::module mm;
+
+    auto tensor =
+        mm.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 100, 128}});
+    auto scalar =
+        mm.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {1e-5f}});
+
+    auto result = migraphx::add_common_op(mm, migraphx::make_op("add"), {tensor, scalar});
+
+    EXPECT(result->get_shape().type() == migraphx::shape::float_type);
+}
+
+int main(int argc, const char* argv[]) { test::run(argc, argv); }