Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mediapipe/tasks/cc/vision/image_segmenter/calculators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/image_segmenter/proto:segmenter_options_cc_proto",
"//mediapipe/tasks/cc/vision/utils:image_utils",
"//mediapipe/util:label_map_cc_proto",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
Expand Down Expand Up @@ -86,6 +87,7 @@ cc_library(
"//mediapipe/tasks/cc/vision/utils:image_utils",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
] + select({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/port/status_macros.h"
Expand Down Expand Up @@ -538,7 +539,7 @@ bool SegmentationPostprocessorGl::HasGlExtension(std::string const& extension) {
#endif // __EMSCRIPTEN__
}

std::vector<std::unique_ptr<Image>>
absl::StatusOr<std::vector<std::unique_ptr<Image>>>
SegmentationPostprocessorGl::GetSegmentationResultGpu(
const Shape& input_shape, const Shape& output_shape, const Tensor& tensor,
const bool produce_confidence_masks, const bool produce_category_mask) {
Expand Down Expand Up @@ -598,14 +599,26 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
// Otherwise, we just try for F16. See b/277656755 for more information.
// TODO: In the future, separate these 3 different restrictions.
// TODO: Also, we should extend this logic to all platforms.
static bool can_use_f32 = HasGlExtension("EXT_color_buffer_float") &&
HasGlExtension("OES_texture_float_linear") &&
HasGlExtension("EXT_float_blend");
static bool can_use_f16_backup =
const bool has_ext_color_buffer_float =
HasGlExtension("EXT_color_buffer_float");
const bool has_oes_texture_float_linear =
HasGlExtension("OES_texture_float_linear");
const bool has_ext_float_blend = HasGlExtension("EXT_float_blend");
const bool has_ext_color_buffer_half_float =
HasGlExtension("EXT_color_buffer_half_float");
RET_CHECK(can_use_f32 || can_use_f16_backup)
<< "Segmentation postprocessing error: GPU does not fully support "
<< "4-channel float32 or float16 formats.";
const bool can_use_f32 = has_ext_color_buffer_float &&
has_oes_texture_float_linear &&
has_ext_float_blend;
const bool can_use_f16_backup = has_ext_color_buffer_half_float;
if (!can_use_f32 && !can_use_f16_backup) {
return absl::FailedPreconditionError(absl::StrFormat(
"Segmentation postprocessing error: GPU does not fully support "
"4-channel float32 or float16 formats. WebGL extensions: "
"EXT_color_buffer_float=%v, OES_texture_float_linear=%v, "
"EXT_float_blend=%v, EXT_color_buffer_half_float=%v",
has_ext_color_buffer_float, has_oes_texture_float_linear,
has_ext_float_blend, has_ext_color_buffer_half_float));
}

const GpuBufferFormat activation_output_format =
can_use_f32 ? GpuBufferFormat::kRGBAFloat128
Expand Down Expand Up @@ -984,10 +997,8 @@ SegmentationPostprocessorGl::GetSegmentationResultGpu(
return absl::OkStatus();
});

if (!status.ok()) {
ABSL_LOG(ERROR) << "Error with rendering: " << status;
}

if (!status.ok())
return status;
return image_outputs;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#ifndef MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SEGMENTATION_POSTPROCESSOR_GL_H_
#define MEDIAPIPE_TASKS_CC_VISION_IMAGE_SEGMENTER_CALCULATORS_SEGMENTATION_POSTPROCESSOR_GL_H_
#include "absl/status/statusor.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/formats/image.h"
#include "mediapipe/framework/formats/tensor.h"
Expand Down Expand Up @@ -41,7 +42,7 @@ class SegmentationPostprocessorGl {
absl::Status Initialize(
CalculatorContext* cc,
TensorsToSegmentationCalculatorOptions const& options);
std::vector<std::unique_ptr<Image>> GetSegmentationResultGpu(
absl::StatusOr<std::vector<std::unique_ptr<Image>>> GetSegmentationResultGpu(
const vision::Shape& input_shape, const vision::Shape& output_shape,
const Tensor& tensor, const bool produce_confidence_masks,
const bool produce_category_mask);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
Expand Down Expand Up @@ -434,49 +435,59 @@ absl::Status TensorsToSegmentationNodeImpl::Process(
options_.segmenter_options().output_type() ==
SegmenterOptions::CONFIDENCE_MASK ||
(cc.confidence_mask_out.Count() > 0);
std::vector<std::unique_ptr<Image>> segmented_masks =
postprocessor_.GetSegmentationResultGpu(
input_shape, output_shape, input_tensor, produce_confidence_masks,
produce_category_mask);
bool new_style = cc.category_mask_out.IsConnected() ||
(cc.confidence_mask_out.Count() > 0);
if (new_style) {
int confidence_mask_count =
options_.confidence_mask_options().output_channels().empty()
? input_shape.channels
: options_.confidence_mask_options().output_channels_size();
if (options_.confidence_mask_options().pack4()) {
confidence_mask_count = (confidence_mask_count + 3) / 4;
}
int category_mask_count = produce_category_mask ? 1 : 0;

// segmented_masks = [confidence_masks if any, category_mask if any]
if (produce_confidence_masks) {
RET_CHECK_EQ(segmented_masks.size(),
confidence_mask_count + category_mask_count)
<< "Confidence mask count mismatch.";
RET_CHECK_EQ(cc.confidence_mask_out.Count(),
segmented_masks.size() - category_mask_count)
<< "Confidence mask output count mismatch.";
for (int i = 0; i < confidence_mask_count; ++i) {
cc.confidence_mask_out.At(i).Send(std::move(segmented_masks[i]));
}
}
if (produce_category_mask) {
int category_mask_index =
produce_confidence_masks ? confidence_mask_count : 0;
cc.category_mask_out.Send(
std::move(segmented_masks[category_mask_index]));
auto gpu_segmented_masks = postprocessor_.GetSegmentationResultGpu(
input_shape, output_shape, input_tensor, produce_confidence_masks,
produce_category_mask);
if (!gpu_segmented_masks.ok()) {
if (gpu_segmented_masks.status().code() !=
absl::StatusCode::kFailedPrecondition) {
return gpu_segmented_masks.status();
}
ABSL_LOG(WARNING) << "Falling back to CPU segmentation postprocessing: "
<< gpu_segmented_masks.status();
} else {
// TODO: remove deprecated output type support.
RET_CHECK_EQ(cc.segmentation_out.Count(), segmented_masks.size())
<< "Segmentation output count mismatch.";
for (int i = 0; i < segmented_masks.size(); ++i) {
cc.segmentation_out.At(i).Send(std::move(segmented_masks[i]));
std::vector<std::unique_ptr<Image>> segmented_masks =
std::move(*gpu_segmented_masks);
bool new_style = cc.category_mask_out.IsConnected() ||
(cc.confidence_mask_out.Count() > 0);
if (new_style) {
int confidence_mask_count =
options_.confidence_mask_options().output_channels().empty()
? input_shape.channels
: options_.confidence_mask_options().output_channels_size();
if (options_.confidence_mask_options().pack4()) {
confidence_mask_count = (confidence_mask_count + 3) / 4;
}
int category_mask_count = produce_category_mask ? 1 : 0;

// segmented_masks = [confidence_masks if any, category_mask if any]
if (produce_confidence_masks) {
RET_CHECK_EQ(segmented_masks.size(),
confidence_mask_count + category_mask_count)
<< "Confidence mask count mismatch.";
RET_CHECK_EQ(cc.confidence_mask_out.Count(),
segmented_masks.size() - category_mask_count)
<< "Confidence mask output count mismatch.";
for (int i = 0; i < confidence_mask_count; ++i) {
cc.confidence_mask_out.At(i).Send(std::move(segmented_masks[i]));
}
}
if (produce_category_mask) {
int category_mask_index =
produce_confidence_masks ? confidence_mask_count : 0;
cc.category_mask_out.Send(
std::move(segmented_masks[category_mask_index]));
}
} else {
// TODO: remove deprecated output type support.
RET_CHECK_EQ(cc.segmentation_out.Count(), segmented_masks.size())
<< "Segmentation output count mismatch.";
for (int i = 0; i < segmented_masks.size(); ++i) {
cc.segmentation_out.At(i).Send(std::move(segmented_masks[i]));
}
}
return absl::OkStatus();
}
return absl::OkStatus();
}
#endif // TASK_SEGMENTATION_USE_GL_POSTPROCESSING

Expand Down