
TFMA Writers

tensorflow_model_analysis.writers

Init module for TensorFlow Model Analysis writers.

Attributes

Writer module-attribute

Writer = NamedTuple(
    "Writer",
    [("stage_name", str), ("ptransform", PTransform)],
)
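
For illustration, a minimal sketch of constructing a custom Writer; the _WriteAsText transform, stage name, and output path below are hypothetical stand-ins:

import apache_beam as beam
from tensorflow_model_analysis import writers


@beam.ptransform_fn
def _WriteAsText(evaluation_item, output_path):
  # Hypothetical sink: writes each element's string form to text files.
  return (evaluation_item
          | 'Format' >> beam.Map(str)
          | 'WriteText' >> beam.io.WriteToText(output_path))


# Pair a Beam stage name with the PTransform that performs the writing.
custom_writer = writers.Writer(
    stage_name='WriteMetricsAsText',
    ptransform=_WriteAsText(output_path='/tmp/metrics'))  # placeholder path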

Functions

EvalConfigWriter

EvalConfigWriter(
    output_path: str,
    eval_config: EvalConfig,
    output_file_format: str = EVAL_CONFIG_FILE_FORMAT,
    data_location: Optional[str] = None,
    data_file_format: Optional[str] = None,
    model_locations: Optional[Dict[str, str]] = None,
    filename: Optional[str] = None,
) -> Writer

Returns eval config writer.

PARAMETER DESCRIPTION
output_path

Output path to write config to.

TYPE: str

eval_config

EvalConfig.

TYPE: EvalConfig

output_file_format

Output file format. Currently only 'json' is supported.

TYPE: str DEFAULT: EVAL_CONFIG_FILE_FORMAT

data_location

Optional path indicating where data is read from. This is only used for display purposes.

TYPE: Optional[str] DEFAULT: None

data_file_format

Optional format of the input examples. This is only used for display purposes.

TYPE: Optional[str] DEFAULT: None

model_locations

Dict of model locations keyed by model name. This is only used for display purposes.

TYPE: Optional[Dict[str, str]] DEFAULT: None

filename

Name of file to store the config as.

TYPE: Optional[str] DEFAULT: None

Source code in tensorflow_model_analysis/writers/eval_config_writer.py
def EvalConfigWriter(  # pylint: disable=invalid-name
    output_path: str,
    eval_config: config_pb2.EvalConfig,
    output_file_format: str = EVAL_CONFIG_FILE_FORMAT,
    data_location: Optional[str] = None,
    data_file_format: Optional[str] = None,
    model_locations: Optional[Dict[str, str]] = None,
    filename: Optional[str] = None) -> writer.Writer:
  """Returns eval config writer.

  Args:
    output_path: Output path to write config to.
    eval_config: EvalConfig.
    output_file_format: Output file format. Currently only 'json' is supported.
    data_location: Optional path indicating where data is read from. This is
      only used for display purposes.
    data_file_format: Optional format of the input examples. This is only used
      for display purposes.
    model_locations: Dict of model locations keyed by model name. This is only
      used for display purposes.
    filename: Name of file to store the config as.
  """
  if data_location is None:
    data_location = '<user provided PCollection>'
  if data_file_format is None:
    data_file_format = '<unknown>'
  if model_locations is None:
    model_locations = {'': '<unknown>'}
  if filename is None:
    filename = EVAL_CONFIG_FILE + '.' + output_file_format

  return writer.Writer(
      stage_name='WriteEvalConfig',
      ptransform=_WriteEvalConfig(  # pylint: disable=no-value-for-parameter
          eval_config=eval_config,
          output_path=output_path,
          output_file_format=output_file_format,
          data_location=data_location,
          data_file_format=data_file_format,
          model_locations=model_locations,
          filename=filename))
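
For illustration, a sketch of creating the config writer; the config contents, data location, and model locations are placeholders, and with the defaults the config is written as a JSON file under output_path:

import tensorflow_model_analysis as tfma
from tensorflow_model_analysis import writers

# Placeholder config with one model and one metric; real configs are richer.
eval_config = tfma.EvalConfig()
eval_config.model_specs.add(name='candidate')
eval_config.metrics_specs.add().metrics.add(class_name='ExampleCount')

# The data and model locations are recorded for display purposes only.
config_writer = writers.EvalConfigWriter(
    output_path='/tmp/eval_result',                      # placeholder path
    eval_config=eval_config,
    data_location='/tmp/examples.tfrecord',              # placeholder path
    model_locations={'candidate': '/tmp/saved_model'})   # placeholder path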

MetricsPlotsAndValidationsWriter

MetricsPlotsAndValidationsWriter(
    output_paths: Dict[str, str],
    eval_config: EvalConfig,
    add_metrics_callbacks: Optional[
        List[AddMetricsCallbackType]
    ] = None,
    metrics_key: str = METRICS_KEY,
    plots_key: str = PLOTS_KEY,
    attributions_key: str = ATTRIBUTIONS_KEY,
    validations_key: str = VALIDATIONS_KEY,
    output_file_format: str = _TFRECORD_FORMAT,
    rubber_stamp: Optional[bool] = False,
    stage_name: str = METRICS_PLOTS_AND_VALIDATIONS_WRITER_STAGE_NAME,
) -> Writer

Returns metrics and plots writer.

Note, sharding will be enabled by default if an output_file_format is provided. The files will be named <output_path>-SSSSS-of-NNNNN.<output_file_format> where SSSSS is the shard number and NNNNN is the number of shards.

PARAMETER DESCRIPTION
output_paths

Output paths keyed by output key (e.g. 'metrics', 'plots', 'validation').

TYPE: Dict[str, str]

eval_config

Eval config.

TYPE: EvalConfig

add_metrics_callbacks

Optional list of metric callbacks (if used).

TYPE: Optional[List[AddMetricsCallbackType]] DEFAULT: None

metrics_key

Name to use for metrics key in Evaluation output.

TYPE: str DEFAULT: METRICS_KEY

plots_key

Name to use for plots key in Evaluation output.

TYPE: str DEFAULT: PLOTS_KEY

attributions_key

Name to use for attributions key in Evaluation output.

TYPE: str DEFAULT: ATTRIBUTIONS_KEY

validations_key

Name to use for validations key in Evaluation output.

TYPE: str DEFAULT: VALIDATIONS_KEY

output_file_format

File format to use when saving files. Currently 'tfrecord' and 'parquet' are supported, and 'tfrecord' is the default. If using parquet, the output metrics and plots files will contain two columns, 'slice_key' and 'serialized_value'. The 'slice_key' column will be a structured column matching the metrics_for_slice_pb2.SliceKey proto. The 'serialized_value' column will contain a serialized MetricsForSlice or PlotsForSlice proto. The validation result file will contain a single column 'serialized_value' which will contain a single serialized ValidationResult proto. A sketch of reading the parquet output back follows this parameter list.

TYPE: str DEFAULT: _TFRECORD_FORMAT

rubber_stamp

True if this model is being rubber stamped. When a model is rubber stamped, diff thresholds will be ignored if an associated baseline model is not passed.

TYPE: Optional[bool] DEFAULT: False

stage_name

The stage name to use when this writer is added to the Beam pipeline.

TYPE: str DEFAULT: METRICS_PLOTS_AND_VALIDATIONS_WRITER_STAGE_NAME
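
As noted above for output_file_format, each parquet row stores the slice key plus a serialized proto in the serialized_value column. A sketch of reading a metrics file back, assuming pandas with a parquet engine is installed and that metrics_for_slice_pb2 is importable from tensorflow_model_analysis.proto; the shard name is a placeholder:

import pandas as pd
from tensorflow_model_analysis.proto import metrics_for_slice_pb2

# Placeholder shard name; shards follow <output_path>-SSSSS-of-NNNNN.parquet.
df = pd.read_parquet('/tmp/eval_result/metrics-00000-of-00001.parquet')

for serialized in df['serialized_value']:
  # Each row holds one serialized MetricsForSlice proto.
  metrics_for_slice = metrics_for_slice_pb2.MetricsForSlice.FromString(serialized)
  print(metrics_for_slice.slice_key)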

Source code in tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py
def MetricsPlotsAndValidationsWriter(  # pylint: disable=invalid-name
    output_paths: Dict[str, str],
    eval_config: config_pb2.EvalConfig,
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]] = None,
    metrics_key: str = constants.METRICS_KEY,
    plots_key: str = constants.PLOTS_KEY,
    attributions_key: str = constants.ATTRIBUTIONS_KEY,
    validations_key: str = constants.VALIDATIONS_KEY,
    output_file_format: str = _TFRECORD_FORMAT,
    rubber_stamp: Optional[bool] = False,
    stage_name: str = METRICS_PLOTS_AND_VALIDATIONS_WRITER_STAGE_NAME
) -> writer.Writer:
  """Returns metrics and plots writer.

  Note, sharding will be enabled by default if an output_file_format is provided.
  The files will be named <output_path>-SSSSS-of-NNNNN.<output_file_format>
  where SSSSS is the shard number and NNNNN is the number of shards.

  Args:
    output_paths: Output paths keyed by output key (e.g. 'metrics', 'plots',
      'validation').
    eval_config: Eval config.
    add_metrics_callbacks: Optional list of metric callbacks (if used).
    metrics_key: Name to use for metrics key in Evaluation output.
    plots_key: Name to use for plots key in Evaluation output.
    attributions_key: Name to use for attributions key in Evaluation output.
    validations_key: Name to use for validations key in Evaluation output.
    output_file_format: File format to use when saving files. Currently
      'tfrecord' and 'parquet' are supported and 'tfrecord' is the default.
      If using parquet, the output metrics and plots files will contain two
      columns, 'slice_key' and 'serialized_value'. The 'slice_key' column will
      be a structured column matching the metrics_for_slice_pb2.SliceKey proto.
      The 'serialized_value' column will contain a serialized MetricsForSlice or
      PlotsForSlice proto. The validation result file will contain a single
      column 'serialized_value' which will contain a single serialized
      ValidationResult proto.
    rubber_stamp: True if this model is being rubber stamped. When a model is
      rubber stamped diff thresholds will be ignored if an associated baseline
      model is not passed.
    stage_name: The stage name to use when this writer is added to the Beam
      pipeline.
  """
  return writer.Writer(
      stage_name=stage_name,
      ptransform=_WriteMetricsPlotsAndValidations(  # pylint: disable=no-value-for-parameter
          output_paths=output_paths,
          eval_config=eval_config,
          add_metrics_callbacks=add_metrics_callbacks or [],
          metrics_key=metrics_key,
          plots_key=plots_key,
          attributions_key=attributions_key,
          validations_key=validations_key,
          output_file_format=output_file_format,
          rubber_stamp=rubber_stamp))
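
For illustration, a sketch of constructing the writer; the output paths are placeholders and the keys use the constants referenced in the signature above:

import tensorflow_model_analysis as tfma
from tensorflow_model_analysis import constants, writers

eval_config = tfma.EvalConfig()  # placeholder config

# One output path prefix per artifact type; shards are written under each.
output_paths = {
    constants.METRICS_KEY: '/tmp/eval_result/metrics',          # placeholder
    constants.PLOTS_KEY: '/tmp/eval_result/plots',              # placeholder
    constants.VALIDATIONS_KEY: '/tmp/eval_result/validations',  # placeholder
}

metrics_writer = writers.MetricsPlotsAndValidationsWriter(
    output_paths=output_paths,
    eval_config=eval_config,
    output_file_format='parquet')  # 'tfrecord' is the default

The resulting Writer is typically passed in the writers list of a higher-level entry point such as tfma.ExtractEvaluateAndWriteResults rather than applied by hand.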

Write

Write(
    evaluation_or_validation: Union[Evaluation, Validation],
    key: str,
    ptransform: PTransform,
) -> Optional[PCollection]

Writes given Evaluation or Validation data using given writer PTransform.

PARAMETER DESCRIPTION
evaluation_or_validation

Evaluation or Validation data.

TYPE: Union[Evaluation, Validation]

key

Key for Evaluation or Validation output to write. It is valid for the key to not exist in the dict (in which case the write is a no-op).

TYPE: str

ptransform

PTransform to use for writing.

TYPE: PTransform

RAISES DESCRIPTION
ValueError

If Evaluation or Validation is empty. The key does not need to exist in the Evaluation or Validation, but the dict must not be empty.

RETURNS DESCRIPTION
Optional[PCollection]

The result of the underlying beam write PTransform. This makes it possible for interactive environments to execute your writer, as well as for downstream Beam stages to make use of the files that are written.

Source code in tensorflow_model_analysis/writers/writer.py
@beam.ptransform_fn
def Write(evaluation_or_validation: Union[evaluator.Evaluation,
                                          validator.Validation], key: str,
          ptransform: beam.PTransform) -> Optional[beam.PCollection]:
  """Writes given Evaluation or Validation data using given writer PTransform.

  Args:
    evaluation_or_validation: Evaluation or Validation data.
    key: Key for Evaluation or Validation output to write. It is valid for the
      key to not exist in the dict (in which case the write is a no-op).
    ptransform: PTransform to use for writing.

  Raises:
    ValueError: If Evaluation or Validation is empty. The key does not need to
      exist in the Evaluation or Validation, but the dict must not be empty.

  Returns:
    The result of the underlying beam write PTransform. This makes it possible
    for interactive environments to execute your writer, as well as for
    downstream Beam stages to make use of the files that are written.
  """
  if not evaluation_or_validation:
    raise ValueError('Evaluations and Validations cannot be empty')
  if key in evaluation_or_validation:
    return evaluation_or_validation[key] | ptransform
  return None
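
For illustration, a sketch of applying Write to an Evaluation dict inside a Beam pipeline; the dict contents, sink, and output path are illustrative:

import apache_beam as beam
from tensorflow_model_analysis import writers

with beam.Pipeline() as pipeline:
  # An Evaluation is a dict mapping output keys to PCollections.
  evaluation = {
      'metrics': pipeline | 'CreateMetrics' >> beam.Create(['metric-a', 'metric-b']),
  }

  # Writes the 'metrics' output with the given sink; if the key were missing
  # from the dict the write would be a no-op and None would be returned.
  _ = (evaluation
       | 'WriteMetrics' >> writers.Write(
           key='metrics',
           ptransform=beam.io.WriteToText('/tmp/metrics')))  # placeholder path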

convert_slice_metrics_to_proto

convert_slice_metrics_to_proto(
    metrics: Tuple[
        SliceKeyOrCrossSliceKeyType, MetricsDict
    ],
    add_metrics_callbacks: Optional[
        List[AddMetricsCallbackType]
    ],
) -> MetricsForSlice

Converts the given slice metrics into serialized proto MetricsForSlice.

PARAMETER DESCRIPTION
metrics

The slice metrics.

TYPE: Tuple[SliceKeyOrCrossSliceKeyType, MetricsDict]

add_metrics_callbacks

A list of metric callbacks. This should be the same list as the one passed to tfma.Evaluate().

TYPE: Optional[List[AddMetricsCallbackType]]

RETURNS DESCRIPTION
MetricsForSlice

The MetricsForSlice proto.

RAISES DESCRIPTION
TypeError

If the type of the feature value in slice key cannot be recognized.

Source code in tensorflow_model_analysis/writers/metrics_plots_and_validations_writer.py
def convert_slice_metrics_to_proto(
    metrics: Tuple[slicer.SliceKeyOrCrossSliceKeyType,
                   metric_types.MetricsDict],
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]]
) -> metrics_for_slice_pb2.MetricsForSlice:
  """Converts the given slice metrics into serialized proto MetricsForSlice.

  Args:
    metrics: The slice metrics.
    add_metrics_callbacks: A list of metric callbacks. This should be the same
      list as the one passed to tfma.Evaluate().

  Returns:
    The MetricsForSlice proto.

  Raises:
    TypeError: If the type of the feature value in slice key cannot be
      recognized.
  """
  result = metrics_for_slice_pb2.MetricsForSlice()
  slice_key, slice_metrics = metrics

  if slicer.is_cross_slice_key(slice_key):
    result.cross_slice_key.CopyFrom(slicer.serialize_cross_slice_key(slice_key))
  else:
    result.slice_key.CopyFrom(slicer.serialize_slice_key(slice_key))

  slice_metrics = slice_metrics.copy()

  if metric_keys.ERROR_METRIC in slice_metrics:
    logging.warning('Error for slice: %s with error message: %s ', slice_key,
                    slice_metrics[metric_keys.ERROR_METRIC])
    result.metrics[metric_keys.ERROR_METRIC].debug_message = slice_metrics[
        metric_keys.ERROR_METRIC]
    return result

  # Convert the metrics from add_metrics_callbacks to the structured output if
  # defined.
  if add_metrics_callbacks and (
      not any(isinstance(k, metric_types.MetricKey) for k in slice_metrics)
  ):
    for add_metrics_callback in add_metrics_callbacks:
      if hasattr(add_metrics_callback, 'populate_stats_and_pop'):
        add_metrics_callback.populate_stats_and_pop(slice_key, slice_metrics,
                                                    result.metrics)
  for key in sorted(slice_metrics):
    value = slice_metrics[key]
    if isinstance(value, types.ValueWithTDistribution):
      unsampled_value = value.unsampled_value
      _, lower_bound, upper_bound = (
          math_util.calculate_confidence_interval(value))
      confidence_interval = metrics_for_slice_pb2.ConfidenceInterval(
          lower_bound=convert_metric_value_to_proto(lower_bound),
          upper_bound=convert_metric_value_to_proto(upper_bound),
          standard_error=convert_metric_value_to_proto(
              value.sample_standard_deviation),
          degrees_of_freedom={'value': value.sample_degrees_of_freedom})
      metric_value = convert_metric_value_to_proto(unsampled_value)
      if isinstance(key, metric_types.MetricKey):
        result.metric_keys_and_values.add(
            key=key.to_proto(),
            value=metric_value,
            confidence_interval=confidence_interval)
      else:
        # For v1 we continue to populate bounded_value for backwards
        # compatibility. If metric can be stored to double_value metrics,
        # replace it with a bounded_value.
        # TODO(b/171992041): remove the string-typed metric key branch once v1
        # code is removed.
        if metric_value.WhichOneof('type') == 'double_value':
          # setting bounded_value clears double_value in the same oneof scope.
          metric_value.bounded_value.value.value = unsampled_value
          metric_value.bounded_value.lower_bound.value = lower_bound
          metric_value.bounded_value.upper_bound.value = upper_bound
          metric_value.bounded_value.methodology = (
              metrics_for_slice_pb2.BoundedValue.POISSON_BOOTSTRAP)
        result.metrics[key].CopyFrom(metric_value)
    elif isinstance(value, metrics_for_slice_pb2.BoundedValue):
      metric_value = metrics_for_slice_pb2.MetricValue(
          double_value=wrappers_pb2.DoubleValue(value=value.value.value))
      confidence_interval = metrics_for_slice_pb2.ConfidenceInterval(
          lower_bound=metrics_for_slice_pb2.MetricValue(
              double_value=wrappers_pb2.DoubleValue(
                  value=value.lower_bound.value)),
          upper_bound=metrics_for_slice_pb2.MetricValue(
              double_value=wrappers_pb2.DoubleValue(
                  value=value.upper_bound.value)))
      result.metric_keys_and_values.add(
          key=key.to_proto(),
          value=metric_value,
          confidence_interval=confidence_interval)
    else:
      metric_value = convert_metric_value_to_proto(value)
      if isinstance(key, metric_types.MetricKey):
        result.metric_keys_and_values.add(
            key=key.to_proto(), value=metric_value)
      else:
        # TODO(b/171992041): remove the string-typed metric key branch once v1
        # code is removed.
        result.metrics[key].CopyFrom(metric_value)
  return result
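
For illustration, a sketch of converting one slice's metrics, assuming MetricKey is importable from tensorflow_model_analysis.metrics.metric_types as referenced in the source above; the slice key and metric values are made up:

from tensorflow_model_analysis.metrics import metric_types
from tensorflow_model_analysis.writers import metrics_plots_and_validations_writer

# A slice key is a tuple of (column, value) pairs; the empty tuple denotes
# the overall (unsliced) dataset.
slice_key = (('gender', 'female'),)
slice_metrics = {
    metric_types.MetricKey(name='accuracy'): 0.83,
    metric_types.MetricKey(name='example_count'): 1200.0,
}

metrics_for_slice = metrics_plots_and_validations_writer.convert_slice_metrics_to_proto(
    (slice_key, slice_metrics), add_metrics_callbacks=None)
print(metrics_for_slice.slice_key)
print(metrics_for_slice.metric_keys_and_values)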