TFMA Utils

tensorflow_model_analysis.utils

Init module for TensorFlow Model Analysis utils.

Classes

CombineFnWithModels

CombineFnWithModels(model_loaders: Dict[str, ModelLoader])

Bases: CombineFn

Abstract class for CombineFns that need the shared models.

Initializes CombineFn using dict of loaders keyed by model location.

Source code in tensorflow_model_analysis/utils/model_util.py
def __init__(self, model_loaders: Dict[str, types.ModelLoader]):
  """Initializes CombineFn using dict of loaders keyed by model location."""
  self._model_loaders = model_loaders
  self._loaded_models = None
  self._model_load_seconds = None
  self._model_load_seconds_distribution = beam.metrics.Metrics.distribution(
      constants.METRICS_NAMESPACE, 'model_load_seconds')
Functions
setup
setup()
Source code in tensorflow_model_analysis/utils/model_util.py
def setup(self):
  if self._loaded_models is None:
    self._loaded_models = {}
    for model_name, model_loader in self._model_loaders.items():
      self._loaded_models[model_name] = model_loader.load(
          model_load_time_callback=self._set_model_load_seconds)
    if self._model_load_seconds is not None:
      self._model_load_seconds_distribution.update(self._model_load_seconds)
      self._model_load_seconds = None
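As a rough, hypothetical sketch (not part of TFMA) of how a subclass might use the models loaded in setup(); the class name and counting logic below are purely illustrative:

# Hypothetical subclass: counts elements while having access to the shared
# models that setup() loads into self._loaded_models.
class _ExampleCountCombineFn(CombineFnWithModels):

  def create_accumulator(self):
    return 0

  def add_input(self, accumulator, element):
    # self._loaded_models (populated by setup()) maps model names to the
    # loaded models and can be consulted here if needed.
    return accumulator + 1

  def merge_accumulators(self, accumulators):
    return sum(accumulators)

  def extract_output(self, accumulator):
    return accumulator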

DoFnWithModels

DoFnWithModels(model_loaders: Dict[str, ModelLoader])

Bases: DoFn

Abstract class for DoFns that need the shared models.

Initializes DoFn using dict of model loaders keyed by model location.

Source code in tensorflow_model_analysis/utils/model_util.py
def __init__(self, model_loaders: Dict[str, types.ModelLoader]):
  """Initializes DoFn using dict of model loaders keyed by model location."""
  self._model_loaders = model_loaders
  self._loaded_models = None
  self._model_load_seconds = None
  self._model_load_seconds_distribution = beam.metrics.Metrics.distribution(
      constants.METRICS_NAMESPACE, 'model_load_seconds')
Functions
finish_bundle
finish_bundle()
Source code in tensorflow_model_analysis/utils/model_util.py
def finish_bundle(self):
  # Must update distribution in finish_bundle instead of setup
  # because Beam metrics are not supported in setup.
  if self._model_load_seconds is not None:
    self._model_load_seconds_distribution.update(self._model_load_seconds)
    self._model_load_seconds = None
process
process(elem)
Source code in tensorflow_model_analysis/utils/model_util.py
def process(self, elem):
  raise NotImplementedError('Subclasses are expected to override this.')
setup
setup()
Source code in tensorflow_model_analysis/utils/model_util.py
def setup(self):
  self._loaded_models = {}
  for model_name, model_loader in self._model_loaders.items():
    self._loaded_models[model_name] = model_loader.load(
        model_load_time_callback=self._set_model_load_seconds)
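For illustration, a minimal hypothetical subclass (the class name and the use of '' as the model name are assumptions for a single-model setup, not library code); subclasses override process and read from self._loaded_models:

# Hypothetical sketch of a DoFnWithModels subclass.
class _PredictDoFn(DoFnWithModels):

  def process(self, elem):
    model = self._loaded_models['']  # populated by setup()
    # ... use the loaded model to produce outputs for this element ...
    yield elem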

Functions

calculate_confidence_interval

calculate_confidence_interval(
    t_distribution_value: ValueWithTDistribution,
)

Calculates confidence intervals based on a 95% confidence level.

Source code in tensorflow_model_analysis/utils/math_util.py
def calculate_confidence_interval(
    t_distribution_value: types.ValueWithTDistribution):
  """Calculate confidence intervals based 95% confidence level."""
  alpha = 0.05
  std_err = t_distribution_value.sample_standard_deviation
  t_stat = stats.t.ppf(1 - (alpha / 2.0),
                       t_distribution_value.sample_degrees_of_freedom)
  # The order of operands matters here because we want to use the
  # std_err.__mul__ operator below, rather than the t_stat.__mul__.
  # TODO(b/197669322): make StructuredMetricValues robust to operand ordering.
  upper_bound = t_distribution_value.sample_mean + std_err * t_stat
  lower_bound = t_distribution_value.sample_mean - std_err * t_stat
  return t_distribution_value.sample_mean, lower_bound, upper_bound
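The same interval can be reproduced directly with scipy; the sketch below uses a plain namedtuple whose field names mirror ValueWithTDistribution (the numeric values are made up for illustration):

# Equivalent 95% confidence interval computed with scipy alone.
import collections
from scipy import stats

TDist = collections.namedtuple(
    'TDist',
    ['sample_mean', 'sample_standard_deviation', 'sample_degrees_of_freedom'])
value = TDist(sample_mean=0.62, sample_standard_deviation=0.05,
              sample_degrees_of_freedom=19)

t_stat = stats.t.ppf(1 - 0.05 / 2.0, value.sample_degrees_of_freedom)
lower = value.sample_mean - value.sample_standard_deviation * t_stat
upper = value.sample_mean + value.sample_standard_deviation * t_stat
# (lower, upper) brackets the sample mean at the 95% confidence level.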

compound_key

compound_key(
    keys: Sequence[str], separator: str = KEY_SEPARATOR
) -> str

Returns a compound key based on a list of keys.

PARAMETER DESCRIPTION
keys

Keys used to make up compound key.

TYPE: Sequence[str]

separator

Separator between keys. To ensure the keys can be parsed out of any compound key created, any use of a separator within a key will be replaced by two separators.

TYPE: str DEFAULT: KEY_SEPARATOR

Source code in tensorflow_model_analysis/utils/util.py
def compound_key(keys: Sequence[str], separator: str = KEY_SEPARATOR) -> str:
  """Returns a compound key based on a list of keys.

  Args:
    keys: Keys used to make up compound key.
    separator: Separator between keys. To ensure the keys can be parsed out of
      any compound key created, any use of a separator within a key will be
      replaced by two separators.
  """
  return separator.join([key.replace(separator, separator * 2) for key in keys])
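A quick usage sketch; an explicit separator is passed here so the example does not depend on the value of KEY_SEPARATOR (which is defined in util.py):

from tensorflow_model_analysis.utils import util

util.compound_key(['head1', 'class_2'], separator='__')  # 'head1__class_2'
# A separator occurring inside a key is doubled so the keys remain parseable:
util.compound_key(['a__b', 'c'], separator='__')          # 'a____b__c'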

create_keys_key

create_keys_key(key: str) -> str

Creates secondary key representing the sparse keys associated with key.

Source code in tensorflow_model_analysis/utils/util.py
def create_keys_key(key: str) -> str:
  """Creates secondary key representing the sparse keys associated with key."""
  return '_'.join([key, KEYS_SUFFIX])

create_values_key

create_values_key(key: str) -> str

Creates secondary key representing sparse values associated with key.

Source code in tensorflow_model_analysis/utils/util.py
def create_values_key(key: str) -> str:
  """Creates secondary key representing sparse values associated with key."""
  return '_'.join([key, VALUES_SUFFIX])
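Both helpers simply append a suffix to the given key; the exact suffixes come from KEYS_SUFFIX and VALUES_SUFFIX in util.py, so the commented outputs below are indicative only:

from tensorflow_model_analysis.utils import util

util.create_keys_key('my_sparse_feature')    # '<key>_<KEYS_SUFFIX>'
util.create_values_key('my_sparse_feature')  # '<key>_<VALUES_SUFFIX>'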

get_baseline_model_spec

get_baseline_model_spec(
    eval_config: EvalConfig,
) -> Optional[ModelSpec]

Returns baseline model spec.

Source code in tensorflow_model_analysis/utils/model_util.py
def get_baseline_model_spec(
    eval_config: config_pb2.EvalConfig) -> Optional[config_pb2.ModelSpec]:
  """Returns baseline model spec."""
  for spec in eval_config.model_specs:
    if spec.is_baseline:
      return spec
  return None

get_by_keys

get_by_keys(
    data: Mapping[str, Any],
    keys: Sequence[Any],
    default_value=None,
    optional: bool = False,
) -> Any

Returns value with given key(s) in (possibly multi-level) dict.

The keys represent multiple levels of indirection into the data. For example, if 3 keys are passed then the data is expected to be a dict of dict of dict. For compatibility with data that uses prefixing to separate keys within a single dict, lookups will also search under the keys joined by '/'. For example, the keys 'head1' and 'probabilities' could be stored in a single dict as 'head1/probabilities'.

PARAMETER DESCRIPTION
data

Dict to get value from.

TYPE: Mapping[str, Any]

keys

Sequence of keys to lookup in data. None keys will be ignored.

TYPE: Sequence[Any]

default_value

Default value if not found.

DEFAULT: None

optional

Whether the key is optional or not. If default_value is None and optional is False, a ValueError will be raised if the key is not found.

TYPE: bool DEFAULT: False

RAISES DESCRIPTION
ValueError

If (non-optional) key is not found.

Source code in tensorflow_model_analysis/utils/util.py
def get_by_keys(
    data: Mapping[str, Any],
    keys: Sequence[Any],
    default_value=None,
    optional: bool = False,
) -> Any:
  """Returns value with given key(s) in (possibly multi-level) dict.

  The keys represent multiple levels of indirection into the data. For example
  if 3 keys are passed then the data is expected to be a dict of dict of dict.
  For compatibily with data that uses prefixing to create separate the keys in a
  single dict, lookups will also be searched for under the keys separated by
  '/'. For example, the keys 'head1' and 'probabilities' could be stored in a
  a single dict as 'head1/probabilties'.

  Args:
    data: Dict to get value from.
    keys: Sequence of keys to lookup in data. None keys will be ignored.
    default_value: Default value if not found.
    optional: Whether the key is optional or not. If default value is None and
      optional is False then a ValueError will be raised if key not found.

  Raises:
    ValueError: If (non-optional) key is not found.
  """
  if not keys:
    raise ValueError('no keys provided to get_by_keys: %s' % data)

  format_keys = lambda keys: '->'.join([str(k) for k in keys if k is not None])

  value = data
  keys_matched = 0
  for i, key in enumerate(keys):
    if key is None:
      keys_matched += 1
      continue

    if not isinstance(value, Mapping):
      raise ValueError('expected dict for "%s" but found %s: %s' %
                       (format_keys(keys[:i + 1]), type(value), data))

    if key in value:
      value = value[key]
      keys_matched += 1
      continue

    # If values have prefixes matching the key, return those values (stripped
    # of the prefix) instead.
    prefix_matches = {}
    for k, v in value.items():
      if k.startswith(key + '/'):
        prefix_matches[k[len(key) + 1:]] = v
    if prefix_matches:
      value = prefix_matches
      keys_matched += 1
      continue

    break

  if keys_matched < len(keys) or isinstance(value, Mapping) and not value:
    if default_value is not None:
      return default_value
    if optional:
      return None
    raise ValueError('"%s" key not found (or value is empty dict): %s' %
                     (format_keys(keys[:keys_matched + 1]), data))
  return value
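An illustrative lookup showing both the multi-level traversal and the '/'-prefix fallback described above (the data values are made up):

from tensorflow_model_analysis.utils import util

data = {'predictions': {'head1/probabilities': [0.2, 0.8]}}

# 'head1' is resolved via the 'head1/' prefix fallback.
util.get_by_keys(data, ['predictions', 'head1', 'probabilities'])  # [0.2, 0.8]

# Missing keys fall back to default_value (or raise ValueError without one).
util.get_by_keys(data, ['labels'], default_value=[], optional=True)  # []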

get_model_spec

get_model_spec(
    eval_config: EvalConfig, model_name: str
) -> Optional[ModelSpec]

Returns model spec with given model name.

Source code in tensorflow_model_analysis/utils/model_util.py
def get_model_spec(eval_config: config_pb2.EvalConfig,
                   model_name: str) -> Optional[config_pb2.ModelSpec]:
  """Returns model spec with given model name."""
  if len(eval_config.model_specs) == 1 and not model_name:
    return eval_config.model_specs[0]
  for spec in eval_config.model_specs:
    if spec.name == model_name:
      return spec
  return None
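A small sketch covering both get_model_spec and get_baseline_model_spec; the spec names 'candidate' and 'baseline' are arbitrary choices for this example:

from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import model_util

eval_config = config_pb2.EvalConfig(model_specs=[
    config_pb2.ModelSpec(name='candidate'),
    config_pb2.ModelSpec(name='baseline', is_baseline=True),
])

model_util.get_model_spec(eval_config, 'candidate')  # the 'candidate' spec
model_util.get_baseline_model_spec(eval_config)      # the is_baseline spec
model_util.get_model_spec(eval_config, 'missing')    # None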

get_model_type

get_model_type(
    model_spec: Optional[ModelSpec],
    model_path: Optional[str] = "",
    tags: Optional[List[str]] = None,
) -> str

Returns model type for given model spec taking into account defaults.

The defaults are chosen such that if a model_path is provided and the model can be loaded as a keras model then TF_KERAS is assumed. Next, if tags are provided and the tags contain 'eval' then TF_ESTIMATOR is assumed. Lastly, if the model spec contains an 'eval' signature TF_ESTIMATOR is assumed; otherwise TF_GENERIC is assumed.

PARAMETER DESCRIPTION
model_spec

Model spec.

TYPE: Optional[ModelSpec]

model_path

Optional model path to verify if keras model.

TYPE: Optional[str] DEFAULT: ''

tags

Optional tags to verify if eval is used.

TYPE: Optional[List[str]] DEFAULT: None

Source code in tensorflow_model_analysis/utils/model_util.py
def get_model_type(model_spec: Optional[config_pb2.ModelSpec],
                   model_path: Optional[str] = '',
                   tags: Optional[List[str]] = None) -> str:
  """Returns model type for given model spec taking into account defaults.

  The defaults are chosen such that if a model_path is provided and the model
  can be loaded as a keras model then TF_KERAS is assumed. Next, if tags
  are provided and the tags contains 'eval' then TF_ESTIMATOR is assumed.
  Lastly, if the model spec contains an 'eval' signature TF_ESTIMATOR is assumed
  otherwise TF_GENERIC is assumed.

  Args:
    model_spec: Model spec.
    model_path: Optional model path to verify if keras model.
    tags: Options tags to verify if eval is used.
  """
  if model_spec and model_spec.model_type:
    return model_spec.model_type

  if model_path:
    try:
      keras_model = tf.keras.models.load_model(model_path)
      # In some cases, tf.keras.models.load_model can successfully load a
      # saved_model but it won't actually be a keras model.
      if isinstance(keras_model, tf.keras.models.Model):
        return constants.TF_KERAS
    except Exception:  # pylint: disable=broad-except
      pass

  # Per the docstring above, an 'eval' tag or an 'eval' signature indicates an
  # estimator-based EvalSavedModel; otherwise fall back to TF_GENERIC.
  if tags and 'eval' in tags:
    return constants.TF_ESTIMATOR

  signature_name = None
  if model_spec:
    if model_spec.signature_name:
      signature_name = model_spec.signature_name
    else:
      signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  if signature_name == 'eval':
    return constants.TF_ESTIMATOR

  return constants.TF_GENERIC
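A minimal sketch of the precedence rules, assuming only the standard config_pb2 fields and TFMA constants shown above:

from tensorflow_model_analysis import constants
from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import model_util

# An explicit model_type on the spec short-circuits all inference.
spec = config_pb2.ModelSpec(model_type=constants.TF_KERAS)
model_util.get_model_type(spec)  # constants.TF_KERAS

# With no spec, path, or tags to inspect, the default is TF_GENERIC.
model_util.get_model_type(None)  # constants.TF_GENERIC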

get_non_baseline_model_specs

get_non_baseline_model_specs(
    eval_config: EvalConfig,
) -> Iterable[ModelSpec]

Returns non-baseline model specs.

Source code in tensorflow_model_analysis/utils/model_util.py
def get_non_baseline_model_specs(
    eval_config: config_pb2.EvalConfig) -> Iterable[config_pb2.ModelSpec]:
  """Returns non-baseline model specs."""
  return [spec for spec in eval_config.model_specs if not spec.is_baseline]

has_change_threshold

has_change_threshold(eval_config: EvalConfig) -> bool

Checks whether the eval_config has any change thresholds.

PARAMETER DESCRIPTION
eval_config

the TFMA eval_config.

TYPE: EvalConfig

RETURNS DESCRIPTION
bool

True when there are change thresholds otherwise False.

Source code in tensorflow_model_analysis/utils/config_util.py
def has_change_threshold(eval_config: config_pb2.EvalConfig) -> bool:
  """Checks whether the eval_config has any change thresholds.

  Args:
    eval_config: the TFMA eval_config.

  Returns:
    True when there are change thresholds otherwise False.
  """

  for metrics_spec in eval_config.metrics_specs:
    for metric in metrics_spec.metrics:
      if metric.threshold.change_threshold.ByteSize():
        return True
      for per_slice_threshold in metric.per_slice_thresholds:
        if per_slice_threshold.threshold.change_threshold.ByteSize():
          return True
      for cross_slice_threshold in metric.cross_slice_thresholds:
        if cross_slice_threshold.threshold.change_threshold.ByteSize():
          return True
    for threshold in metrics_spec.thresholds.values():
      if threshold.change_threshold.ByteSize():
        return True
    for per_slice_thresholds in metrics_spec.per_slice_thresholds.values():
      for per_slice_threshold in per_slice_thresholds.thresholds:
        if per_slice_threshold.threshold.change_threshold.ByteSize():
          return True
    for cross_slice_thresholds in metrics_spec.cross_slice_thresholds.values():
      for cross_slice_threshold in cross_slice_thresholds.thresholds:
        if cross_slice_threshold.threshold.change_threshold.ByteSize():
          return True
  return False
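A sketch of a config that carries a change threshold, assuming the standard config_pb2 proto fields (MetricsSpec, MetricConfig, MetricThreshold, GenericChangeThreshold); the metric and threshold values are illustrative:

from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import config_util

eval_config = config_pb2.EvalConfig(metrics_specs=[
    config_pb2.MetricsSpec(metrics=[
        config_pb2.MetricConfig(
            class_name='AUC',
            threshold=config_pb2.MetricThreshold(
                change_threshold=config_pb2.GenericChangeThreshold(
                    direction=config_pb2.MetricDirection.HIGHER_IS_BETTER,
                    absolute={'value': 1e-10}))),
    ]),
])

config_util.has_change_threshold(eval_config)              # True
config_util.has_change_threshold(config_pb2.EvalConfig())  # False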

merge_extracts

merge_extracts(
    extracts: List[Extracts],
    squeeze_two_dim_vector: bool = True,
) -> Extracts

Merges list of extracts into a single extract with multidimensional data.

Note: Running split_extracts followed by merge_extracts with default options will not reproduce the exact shape of the original extracts. Arrays of shape (x, 1) will be flattened to (x,). To maintain the original shape of extract values with array shape (x, 1), run split_extracts(extracts, expand_zero_dims=False) and merge_extracts(extracts, squeeze_two_dim_vector=False).

PARAMETER DESCRIPTION
extracts

Batched TFMA Extracts.

TYPE: List[Extracts]

squeeze_two_dim_vector

Determines how the function handles arrays of shape (x, 1). If squeeze_two_dim_vector is True, the array will be squeezed to shape (x,).

TYPE: bool DEFAULT: True

RETURNS DESCRIPTION
Extracts

A single Extracts whose values have been grouped into batches.

Source code in tensorflow_model_analysis/utils/util.py
def merge_extracts(extracts: List[types.Extracts],
                   squeeze_two_dim_vector: bool = True) -> types.Extracts:
  """Merges list of extracts into a single extract with multidimensional data.

  Note: Running split_extracts followed by merge extracts with default options
    will not reproduce the exact shape of the original extracts. Arrays in shape
    (x,1) will be flattened to (x,). To maintain the original shape of extract
    values of array shape (x,1), you must run with these options:
    split_extracts(extracts, expand_zero_dims=False)
    merge_extracts(extracts, squeeze_two_dim_vector=False)
  Args:
    extracts: Batched TFMA Extracts.
    squeeze_two_dim_vector: Determines how the function will handle arrays of
      shape (x,1). If squeeze_two_dim_vector is True, the array will be squeezed
      to shape (x,).

  Returns:
    A single Extracts whose values have been grouped into batches.
  """

  def merge_with_lists(target: types.Extracts, index: int, key: str, value: Any,
                       num_extracts: int):
    """Merges key and value into the target extracts as a list of values.

    Args:
     target: The extract to store all merged all the data.
     index: The index at which the value should be stored. It is in accordance
       with the order of extracts in the batch.
     key: The key of the key-value pair to store in the target.
     value: The value of the key-value pair to store in the target.
     num_extracts: The total number of extracts to be merged in this target.
    """
    if isinstance(value, Mapping):
      if key not in target:
        target[key] = {}
      target = target[key]
      for k, v in value.items():
        merge_with_lists(target, index, k, v, num_extracts)
    else:
      # If key is newly found, we create a list with length of extracts,
      # so that every value of the i th extracts will go to the i th position.
      # And the extracts without this key will have value np.array([]).
      if key not in target:
        target[key] = [np.array([])] * num_extracts
      target[key][index] = value

  def merge_lists(target: types.Extracts) -> types.Extracts:
    """Converts target's leaves which are lists to batched np.array's, etc."""
    if isinstance(target, Mapping):
      result = {}
      for key, value in target.items():
        try:
          result[key] = merge_lists(value)
        except Exception as e:
          raise RuntimeError(
              f'Failed to convert value for key: {key} and value: {value}'
          ) from e
      return result
    elif (
        target
        and np.any(
            [isinstance(t, tf.compat.v1.SparseTensorValue) for t in target]
        )
        or np.any(
            [isinstance(t, types.SparseTensorValue) for t in target]
        )
    ):
      t = tf.compat.v1.sparse_concat(
          0,
          [tf.sparse.expand_dims(to_tensorflow_tensor(t), 0) for t in target],
          expand_nonconcat_dim=True)
      return to_tensor_value(t)
    elif target and np.any(
        [isinstance(t, types.RaggedTensorValue) for t in target]):
      t = tf.concat(
          [tf.expand_dims(to_tensorflow_tensor(t), 0) for t in target], 0)
      return to_tensor_value(t)
    elif all(isinstance(t, np.ndarray)
             for t in target) and len({t.shape for t in target}) > 1:
      target = (t.squeeze() for t in target)
      return types.VarLenTensorValue.from_dense_rows(target)
    # If all value in the target are scalar numpy array, we stack them.
    # This is to avoid np.array([np.array(b'abc'), np.array(b'abcd')])
    # and stack to np.array([b'abc', b'abcd'])
    elif all(isinstance(t, np.ndarray) and t.shape == () for t in target):  # pylint: disable=g-explicit-bool-comparison
      return np.stack(target)
    elif all(t is None for t in target):
      return None
    else:
      # Compatibility shim for NumPy 1.24. See:
      # https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
      try:
        arr = np.array(target)
      except ValueError:
        arr = np.array(target, dtype=object)
      # Flatten values that were originally single item lists into a single list
      # e.g. [[1], [2], [3]] -> [1, 2, 3]
      if squeeze_two_dim_vector and len(arr.shape) == 2 and arr.shape[1] == 1:
        return arr.squeeze(axis=1)
      return arr

  result = {}
  num_extracts = len(extracts)
  for i, x in enumerate(extracts):
    if x:
      for k, v in x.items():
        merge_with_lists(result, i, k, v, num_extracts)
  return merge_lists(result)
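A small illustration of the batching behavior and the squeeze_two_dim_vector flag described in the note above (the extract contents are made up):

import numpy as np
from tensorflow_model_analysis.utils import util

extracts = [{'labels': np.array([1.0])}, {'labels': np.array([0.0])}]

util.merge_extracts(extracts)['labels'].shape
# (2,) -- arrays of shape (x, 1) are squeezed by default

util.merge_extracts(extracts, squeeze_two_dim_vector=False)['labels'].shape
# (2, 1) -- the original shape is preserved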

model_construct_fn

model_construct_fn(
    eval_saved_model_path: Optional[str] = None,
    add_metrics_callbacks: Optional[
        List[AddMetricsCallbackType]
    ] = None,
    include_default_metrics: Optional[bool] = None,
    additional_fetches: Optional[List[str]] = None,
    blacklist_feature_fetches: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    model_type: Optional[str] = TFMA_EVAL,
) -> Callable[[], Any]

Returns function for constructing shared models.

Source code in tensorflow_model_analysis/utils/model_util.py
def model_construct_fn(  # pylint: disable=invalid-name
    eval_saved_model_path: Optional[str] = None,
    add_metrics_callbacks: Optional[List[types.AddMetricsCallbackType]] = None,
    include_default_metrics: Optional[bool] = None,
    additional_fetches: Optional[List[str]] = None,
    blacklist_feature_fetches: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    model_type: Optional[str] = constants.TFMA_EVAL,
) -> Callable[[], Any]:
  """Returns function for constructing shared models."""

  if tags is None:
    raise ValueError('Model tags must be specified.')

  def construct_fn():  # pylint: disable=invalid-name
    """Function for constructing shared models."""
    # If we are evaluating on TPU, initialize the TPU.
    # TODO(b/143484017): Add model warmup for TPU.
    if tf.saved_model.TPU in tags:
      tf.tpu.experimental.initialize_tpu_system()

    if model_type == constants.TF_KERAS:
      model = tf.keras.models.load_model(eval_saved_model_path)
    elif model_type == constants.TF_LITE:
      # The tf.lite.Interpreter is not thread-safe so we only load the model
      # file's contents and leave construction of the Interpreter up to the
      # PTransform using it.
      model_filename = os.path.join(eval_saved_model_path, _TFLITE_FILE_NAME)
      with tf.io.gfile.GFile(model_filename, 'rb') as model_file:
        model_bytes = model_file.read()

      # If a SavedModel is present in the same directory, load it as well.
      # This allows the SavedModel to be used for computing the
      # Transformed Features and Labels.
      if (tf.io.gfile.exists(
          os.path.join(eval_saved_model_path,
                       tf.saved_model.SAVED_MODEL_FILENAME_PB)) or
          tf.io.gfile.exists(
              os.path.join(eval_saved_model_path,
                           tf.saved_model.SAVED_MODEL_FILENAME_PBTXT))):
        model = tf.compat.v1.saved_model.load_v2(
            eval_saved_model_path, tags=tags)
        model.contents = model_bytes
      else:
        model = ModelContents(model_bytes)

    elif model_type == constants.TF_JS:
      # We invoke TFJS models via a subprocess call. So this call is no-op.
      return None
    else:
      model = tf.compat.v1.saved_model.load_v2(eval_saved_model_path, tags=tags)
    return model

  return construct_fn
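A minimal usage sketch; the SavedModel path below is a placeholder and the model_type/tags are ordinary defaults for a generic SavedModel:

import tensorflow as tf
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.utils import model_util

# '/tmp/my_saved_model' is a hypothetical path; tags are required.
construct_fn = model_util.model_construct_fn(
    eval_saved_model_path='/tmp/my_saved_model',
    tags=[tf.saved_model.SERVING],
    model_type=constants.TF_GENERIC)
model = construct_fn()  # loads the SavedModel when invoked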

unique_key

unique_key(
    key: str,
    current_keys: List[str],
    update_keys: Optional[bool] = False,
) -> str

Returns a unique key given a list of current keys.

If the key exists in current_keys then a new key with _1, _2, ..., etc appended will be returned, otherwise the key will be returned as passed.

PARAMETER DESCRIPTION
key

desired key name.

TYPE: str

current_keys

List of current key names.

TYPE: List[str]

update_keys

True to append the new key to current_keys.

TYPE: Optional[bool] DEFAULT: False

Source code in tensorflow_model_analysis/utils/util.py
def unique_key(key: str,
               current_keys: List[str],
               update_keys: Optional[bool] = False) -> str:
  """Returns a unique key given a list of current keys.

  If the key exists in current_keys then a new key with _1, _2, ..., etc
  appended will be returned, otherwise the key will be returned as passed.

  Args:
    key: desired key name.
    current_keys: List of current key names.
    update_keys: True to append the new key to current_keys.
  """
  index = 1
  k = key
  while k in current_keys:
    k = '%s_%d' % (key, index)
    index += 1
  if update_keys:
    current_keys.append(k)
  return k
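A quick usage sketch:

from tensorflow_model_analysis.utils import util

util.unique_key('key', ['key'])           # 'key_1'
util.unique_key('key', ['key', 'key_1'])  # 'key_2'

keys = ['key']
util.unique_key('key', keys, update_keys=True)  # 'key_1'; keys becomes ['key', 'key_1']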

update_eval_config_with_defaults

update_eval_config_with_defaults(
    eval_config: EvalConfig,
    maybe_add_baseline: Optional[bool] = None,
    maybe_remove_baseline: Optional[bool] = None,
    has_baseline: Optional[bool] = False,
    rubber_stamp: Optional[bool] = False,
) -> EvalConfig

Returns a new config with default settings applied.

a) Add or remove a model_spec according to "has_baseline".
b) Fix the model names (model_spec.name) to tfma.CANDIDATE_KEY and tfma.BASELINE_KEY.
c) Update the metrics_specs with the fixed model name.

PARAMETER DESCRIPTION
eval_config

Original eval config.

TYPE: EvalConfig

maybe_add_baseline

DEPRECATED. True to add a baseline ModelSpec to the config as a copy of the candidate ModelSpec that should already be present. This is only applied if a single ModelSpec already exists in the config and that spec doesn't have a name associated with it. When applied the model specs will use the names tfma.CANDIDATE_KEY and tfma.BASELINE_KEY. Only one of maybe_add_baseline or maybe_remove_baseline should be used.

TYPE: Optional[bool] DEFAULT: None

maybe_remove_baseline

DEPRECATED. True to remove a baseline ModelSpec from the config if it already exists. Removal of the baseline also removes any change thresholds. Only one of maybe_add_baseline or maybe_remove_baseline should be used.

TYPE: Optional[bool] DEFAULT: None

has_baseline

True to add a baseline ModelSpec to the config as a copy of the candidate ModelSpec that should already be present. This is only applied if a single ModelSpec already exists in the config and that spec doesn't have a name associated with it. When applied the model specs will use the names tfma.CANDIDATE_KEY and tfma.BASELINE_KEY. False to remove a baseline ModelSpec from the config if it already exists. Removal of the baseline also removes any change thresholds. Only one of has_baseline or maybe_remove_baseline should be used.

TYPE: Optional[bool] DEFAULT: False

rubber_stamp

True if this model is being rubber stamped. When a model is rubber stamped diff thresholds will be ignored if an associated baseline model is not passed.

TYPE: Optional[bool] DEFAULT: False

RAISES DESCRIPTION
RuntimeError

on missing baseline model for non-rubberstamp cases.

Source code in tensorflow_model_analysis/utils/config_util.py
def update_eval_config_with_defaults(
    eval_config: config_pb2.EvalConfig,
    maybe_add_baseline: Optional[bool] = None,
    maybe_remove_baseline: Optional[bool] = None,
    has_baseline: Optional[bool] = False,
    rubber_stamp: Optional[bool] = False) -> config_pb2.EvalConfig:
  """Returns a new config with default settings applied.

  a) Add or remove a model_spec according to "has_baseline".
  b) Fix the model names (model_spec.name) to tfma.CANDIDATE_KEY and
     tfma.BASELINE_KEY.
  c) Update the metrics_specs with the fixed model name.

  Args:
    eval_config: Original eval config.
    maybe_add_baseline: DEPRECATED. True to add a baseline ModelSpec to the
      config as a copy of the candidate ModelSpec that should already be
      present. This is only applied if a single ModelSpec already exists in the
      config and that spec doesn't have a name associated with it. When applied
      the model specs will use the names tfma.CANDIDATE_KEY and
      tfma.BASELINE_KEY. Only one of maybe_add_baseline or maybe_remove_baseline
      should be used.
    maybe_remove_baseline: DEPRECATED. True to remove a baseline ModelSpec from
      the config if it already exists. Removal of the baseline also removes any
      change thresholds. Only one of maybe_add_baseline or maybe_remove_baseline
      should be used.
    has_baseline: True to add a baseline ModelSpec to the config as a copy of
      the candidate ModelSpec that should already be present. This is only
      applied if a single ModelSpec already exists in the config and that spec
      doesn't have a name associated with it. When applied the model specs will
      use the names tfma.CANDIDATE_KEY and tfma.BASELINE_KEY. False to remove a
      baseline ModelSpec from the config if it already exists. Removal of the
      baseline also removes any change thresholds. Only one of has_baseline or
      maybe_remove_baseline should be used.
    rubber_stamp: True if this model is being rubber stamped. When a model is
      rubber stamped diff thresholds will be ignored if an associated baseline
      model is not passed.

  Raises:
    RuntimeError: on missing baseline model for non-rubberstamp cases.
  """
  if (not has_baseline and has_change_threshold(eval_config) and
      not rubber_stamp):
    # TODO(b/173657964): Raise an error instead of logging an error.
    raise RuntimeError(
        'There are change thresholds, but the baseline is missing. '
        'This is allowed only when rubber stamping (first run).')

  updated_config = config_pb2.EvalConfig()
  updated_config.CopyFrom(eval_config)
  # if user requests CIs but doesn't set method, use JACKKNIFE
  if (eval_config.options.compute_confidence_intervals.value and
      eval_config.options.confidence_intervals.method ==
      config_pb2.ConfidenceIntervalOptions.UNKNOWN_CONFIDENCE_INTERVAL_METHOD):
    updated_config.options.confidence_intervals.method = (
        config_pb2.ConfidenceIntervalOptions.JACKKNIFE)
  if maybe_add_baseline and maybe_remove_baseline:
    raise ValueError('only one of maybe_add_baseline and maybe_remove_baseline '
                     'should be used')
  if maybe_add_baseline or maybe_remove_baseline:
    logging.warning(
        """"maybe_add_baseline" and "maybe_remove_baseline" are deprecated,
        please use "has_baseline" instead.""")
    if has_baseline:
      raise ValueError(
          """"maybe_add_baseline" and "maybe_remove_baseline" are ignored if
          "has_baseline" is set.""")
  if has_baseline is not None:
    if has_baseline:
      maybe_add_baseline = True
    else:
      maybe_remove_baseline = True

  # Has a baseline model.
  if (maybe_add_baseline and len(updated_config.model_specs) == 1 and
      not updated_config.model_specs[0].name):
    baseline = updated_config.model_specs.add()
    baseline.CopyFrom(updated_config.model_specs[0])
    baseline.name = constants.BASELINE_KEY
    baseline.is_baseline = True
    updated_config.model_specs[0].name = constants.CANDIDATE_KEY
    logging.info(
        'Adding default baseline ModelSpec based on the candidate ModelSpec '
        'provided. The candidate model will be called "%s" and the baseline '
        'will be called "%s": updated_config=\n%s', constants.CANDIDATE_KEY,
        constants.BASELINE_KEY, updated_config)

  # Does not have a baseline.
  if maybe_remove_baseline:
    tmp_model_specs = []
    for model_spec in updated_config.model_specs:
      if not model_spec.is_baseline:
        tmp_model_specs.append(model_spec)
    del updated_config.model_specs[:]
    updated_config.model_specs.extend(tmp_model_specs)
    for metrics_spec in updated_config.metrics_specs:
      for metric in metrics_spec.metrics:
        if metric.threshold.ByteSize():
          metric.threshold.ClearField('change_threshold')
        for per_slice_threshold in metric.per_slice_thresholds:
          if per_slice_threshold.threshold.ByteSize():
            per_slice_threshold.threshold.ClearField('change_threshold')
        for cross_slice_threshold in metric.cross_slice_thresholds:
          if cross_slice_threshold.threshold.ByteSize():
            cross_slice_threshold.threshold.ClearField('change_threshold')
      for threshold in metrics_spec.thresholds.values():
        if threshold.ByteSize():
          threshold.ClearField('change_threshold')
      for per_slice_thresholds in metrics_spec.per_slice_thresholds.values():
        for per_slice_threshold in per_slice_thresholds.thresholds:
          if per_slice_threshold.threshold.ByteSize():
            per_slice_threshold.threshold.ClearField('change_threshold')
      for cross_slice_thresholds in metrics_spec.cross_slice_thresholds.values(
      ):
        for cross_slice_threshold in cross_slice_thresholds.thresholds:
          if cross_slice_threshold.threshold.ByteSize():
            cross_slice_threshold.threshold.ClearField('change_threshold')
    logging.info(
        'Request was made to ignore the baseline ModelSpec and any change '
        'thresholds. This is likely because a baseline model was not provided: '
        'updated_config=\n%s', updated_config)

  if not updated_config.model_specs:
    updated_config.model_specs.add()

  model_names = []
  for spec in updated_config.model_specs:
    model_names.append(spec.name)
  if len(model_names) == 1 and model_names[0]:
    logging.info(
        'ModelSpec name "%s" is being ignored and replaced by "" because a '
        'single ModelSpec is being used', model_names[0])
    updated_config.model_specs[0].name = ''
    model_names = ['']
  for spec in updated_config.metrics_specs:
    if not spec.model_names:
      spec.model_names.extend(model_names)
    elif len(model_names) == 1:
      del spec.model_names[:]
      spec.model_names.append('')

  return updated_config
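A sketch of the has_baseline behavior, assuming only the standard config_pb2 fields; the label_key value is illustrative:

from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import config_util

eval_config = config_pb2.EvalConfig(
    model_specs=[config_pb2.ModelSpec(label_key='label')])

updated = config_util.update_eval_config_with_defaults(
    eval_config, has_baseline=True)
# updated.model_specs now contains the candidate spec (named tfma.CANDIDATE_KEY)
# plus a baseline copy (named tfma.BASELINE_KEY with is_baseline=True), and the
# metrics_specs model_names are updated accordingly.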

verify_and_update_eval_shared_models

verify_and_update_eval_shared_models(
    eval_shared_model: Optional[
        MaybeMultipleEvalSharedModels
    ],
) -> Optional[List[EvalSharedModel]]

Verifies eval shared models and normalizes them to produce a single list.

The output is normalized such that if a list or dict contains a single entry, the model name will always be empty.

PARAMETER DESCRIPTION
eval_shared_model

None, a single model, a list of models, or a dict of models keyed by model name.

TYPE: Optional[MaybeMultipleEvalSharedModels]

RETURNS DESCRIPTION
Optional[List[EvalSharedModel]]

A list of models or None.

Source code in tensorflow_model_analysis/utils/model_util.py
def verify_and_update_eval_shared_models(
    eval_shared_model: Optional[types.MaybeMultipleEvalSharedModels]
) -> Optional[List[types.EvalSharedModel]]:
  """Verifies eval shared models and normnalizes to produce a single list.

  The output is normalized such that if a list or dict contains a single entry,
  the model name will always be empty.

  Args:
    eval_shared_model: None, a single model, a list of models, or a dict of
      models keyed by model name.

  Returns:
    A list of models or None.

  Raises:
    ValueError if dict is passed and keys don't match model names or a
    multi-item list is passed without model names.
  """
  if not eval_shared_model:
    return None
  eval_shared_models = []
  if isinstance(eval_shared_model, dict):
    for k, v in eval_shared_model.items():
      if v.model_name and k and k != v.model_name:
        raise ValueError('keys for EvalSharedModel dict do not match '
                         'model_names: dict={}'.format(eval_shared_model))
      if not v.model_name and k:
        v = v._replace(model_name=k)
      eval_shared_models.append(v)
  elif isinstance(eval_shared_model, list):
    # Ensure we don't modify the input list when updating model_name, below.
    eval_shared_models = eval_shared_model.copy()
  else:
    eval_shared_models = [eval_shared_model]
  if len(eval_shared_models) > 1:
    for v in eval_shared_models:
      if not v.model_name:
        raise ValueError(
            'model_name is required when passing multiple EvalSharedModels: '
            'eval_shared_models={}'.format(eval_shared_models))
  # To maintain consistency between settings where single models are used,
  # always use '' as the model name regardless of whether a name is passed.
  elif len(eval_shared_models) == 1 and eval_shared_models[0].model_name:
    eval_shared_models[0] = eval_shared_models[0]._replace(model_name='')
  # Normalizes model types to TFMA_EVAL when appropriate.
  for i, model in enumerate(eval_shared_models):
    assert isinstance(model, types.EvalSharedModel)
  return eval_shared_models  # pytype: disable=bad-return-type  # py310-upgrade

verify_eval_config

verify_eval_config(
    eval_config: EvalConfig,
    baseline_required: Optional[bool] = None,
)

Verifies eval config.

Source code in tensorflow_model_analysis/utils/config_util.py
def verify_eval_config(eval_config: config_pb2.EvalConfig,
                       baseline_required: Optional[bool] = None):
  """Verifies eval config."""
  if not eval_config.model_specs:
    raise ValueError(
        'At least one model_spec is required: eval_config=\n{}'.format(
            eval_config))

  model_specs_by_name = {}
  baseline = None
  for spec in eval_config.model_specs:
    if spec.label_key and spec.label_keys:
      raise ValueError('only one of label_key or label_keys should be used at '
                       'a time: model_spec=\n{}'.format(spec))
    if spec.prediction_key and spec.prediction_keys:
      raise ValueError(
          'only one of prediction_key or prediction_keys should be used at '
          'a time: model_spec=\n{}'.format(spec))
    if spec.example_weight_key and spec.example_weight_keys:
      raise ValueError(
          'only one of example_weight_key or example_weight_keys should be '
          'used at a time: model_spec=\n{}'.format(spec))
    if spec.name in model_specs_by_name:
      raise ValueError(
          'more than one model_spec found for model "{}": {}'.format(
              spec.name, [spec, model_specs_by_name[spec.name]]))
    model_specs_by_name[spec.name] = spec
    if spec.is_baseline:
      if baseline is not None:
        raise ValueError('only one model_spec may be a baseline, found: '
                         '{} and {}'.format(spec, baseline))
      baseline = spec

  if len(model_specs_by_name) > 1 and '' in model_specs_by_name:
    raise ValueError('A name is required for all ModelSpecs when multiple '
                     'models are used: eval_config=\n{}'.format(eval_config))

  if baseline_required and not baseline:
    raise ValueError(
        'A baseline ModelSpec is required: eval_config=\n{}'.format(
            eval_config))

  # Raise exception if per_slice_thresholds has no slicing_specs.
  for metric_spec in eval_config.metrics_specs:
    for name, per_slice_thresholds in metric_spec.per_slice_thresholds.items():
      for per_slice_threshold in per_slice_thresholds.thresholds:
        if not per_slice_threshold.slicing_specs:
          raise ValueError(
              'slicing_specs must be set on per_slice_thresholds but found '
              f'per_slice_threshold=\n{per_slice_threshold}\n'
              f'for metric name {name} in metric_spec:\n{metric_spec}'
          )
    for metric_config in metric_spec.metrics:
      for per_slice_threshold in metric_config.per_slice_thresholds:
        if not per_slice_threshold.slicing_specs:
          raise ValueError(
              'slicing_specs must be set on per_slice_thresholds but found '
              f'per_slice_threshold=\n{per_slice_threshold}\n'
              f'for metric config:\t{metric_config}'
          )
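
A short sketch of the basic checks, assuming only the standard config_pb2 fields; the label_key value is illustrative:

from tensorflow_model_analysis.proto import config_pb2
from tensorflow_model_analysis.utils import config_util

# A single ModelSpec does not require a name.
config_util.verify_eval_config(
    config_pb2.EvalConfig(model_specs=[config_pb2.ModelSpec(label_key='label')]))

# A config without any model_specs is rejected.
try:
  config_util.verify_eval_config(config_pb2.EvalConfig())
except ValueError as e:
  print(e)  # 'At least one model_spec is required: ...'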