Skip to content
This repository was archived by the owner on Dec 29, 2022. It is now read-only.
This repository was archived by the owner on Dec 29, 2022. It is now read-only.

seq2seq checkpoint restore for transfer learning #356

@nweir127

Description

@nweir127

I am using code built on top of train.py and infer.py (files unchanged) from the seq2seq tutorial. I want to do transfer learning/loading from checkpoints but am unfamiliar with the tf.contrib.learn.Estimator and seq2seq.contrib.experiment environment.

I basically want to incorporate the checkpoint load step from infer.py into training:

saver = tf.train.Saver()
  checkpoint_path = FLAGS.checkpoint_path
  if not checkpoint_path:
    checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)

  def session_init_op(_scaffold, sess):
    saver.restore(sess, checkpoint_path)
    tf.logging.info("Restored model from %s", checkpoint_path)

How/where along the pipeline should I be inserting the script?

def create_experiment(output_dir):
  """
  Creates a new Experiment instance.

  Args:
    output_dir: Output directory for model checkpoints and summaries.
  """

  config = run_config.RunConfig(
      tf_random_seed=FLAGS.tf_random_seed,
      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
      gpu_memory_fraction=FLAGS.gpu_memory_fraction)
  config.tf_config.gpu_options.allow_growth = FLAGS.gpu_allow_growth
  config.tf_config.log_device_placement = FLAGS.log_device_placement

  train_options = training_utils.TrainOptions(
      model_class=FLAGS.model,
      model_params=FLAGS.model_params)
  # On the main worker, save training options
  if config.is_chief:
    gfile.MakeDirs(output_dir)
    train_options.dump(output_dir)

  bucket_boundaries = None
  if FLAGS.buckets:
    bucket_boundaries = list(map(int, FLAGS.buckets.split(",")))

  # Training data input pipeline
  train_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_train,
      mode=tf.contrib.learn.ModeKeys.TRAIN)

  # Create training input function
  train_input_fn = training_utils.create_input_fn(
      pipeline=train_input_pipeline,
      batch_size=FLAGS.batch_size,
      bucket_boundaries=bucket_boundaries,
      scope="train_input_fn")

  # Development data input pipeline
  dev_input_pipeline = input_pipeline.make_input_pipeline_from_def(
      def_dict=FLAGS.input_pipeline_dev,
      mode=tf.contrib.learn.ModeKeys.EVAL,
      shuffle=False, num_epochs=1)

  # Create eval input function
  eval_input_fn = training_utils.create_input_fn(
      pipeline=dev_input_pipeline,
      batch_size=FLAGS.batch_size,
      allow_smaller_final_batch=True,
      scope="dev_input_fn")


  def model_fn(features, labels, params, mode):
    """Builds the model graph"""
    model = _create_from_dict({
        "class": train_options.model_class,
        "params": train_options.model_params
    }, models, mode=mode)
    return model(features, labels, params)

  estimator = tf.contrib.learn.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=config,
      params=FLAGS.model_params)

  # Create hooks
  train_hooks = []
  for dict_ in FLAGS.hooks:
    hook = _create_from_dict(
        dict_, hooks,
        model_dir=estimator.model_dir,
        run_config=config)
    train_hooks.append(hook)

  # Create metrics
  eval_metrics = {}
  for dict_ in FLAGS.metrics:
    metric = _create_from_dict(dict_, metric_specs)
    eval_metrics[metric.name] = metric

  saver = tf.train.Saver()
  checkpoint_path = FLAGS.checkpoint_path
  if not checkpoint_path:
      checkpoint_path = tf.train.latest_checkpoint(FLAGS.model_dir)

  saver.restore(sess, checkpoint_path)

  ## what is PatchedExperiment

  experiment = PatchedExperiment(
      
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      min_eval_frequency=FLAGS.eval_every_n_steps,
      train_steps=FLAGS.train_steps,
      eval_steps=None,
      eval_metrics=eval_metrics,
      train_monitors=train_hooks)

  return experiment

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions