Code for VeLO 1: Training Versatile Learned Optimizers by Scaling Up

This post walks through how to train a model with VeLO.

It is based on the Colab notebook at https://colab.research.google.com/drive/1-ms12IypE-EdDSNjhFMdRdBbMnH94zpH#scrollTo=RQBACAPQZyB- and shows how to use a learned optimizer from the VeLO family.

Accelerator Setup, Dependency Installation, and Imports

# Set the accelerator type; in most lab setups only GPUs are available
Accelerator_Type = 'GPU' #@param ["GPU", "TPU", "CPU"]

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)


# Install learned_optimization, which provides the predefined tasks,
# the VeLO prefab optimizers, and the evaluation utilities used below.
!pip install git+https://github.com/google/learned_optimization.git

# JAX is Google's numerical computing library combining Autograd and XLA:
# it offers a NumPy-like API with automatic differentiation and JIT compilation.
import jax
if Accelerator_Type == 'TPU':
  import jax.tools.colab_tpu
  jax.tools.colab_tpu.setup_tpu()

from learned_optimization.tasks.fixed import image_mlp
from learned_optimization.research.general_lopt import prefab
from learned_optimization import eval_training

from matplotlib import pyplot as plt
from learned_optimization import notebook_utils as nu
import numpy as onp


from learned_optimization.baselines import utils
import os
# Point the library at Google's precomputed baseline archives on GCS so that
# baseline training curves can be loaded instead of being recomputed locally.
os.environ["LOPT_BASELINE_ARCHIVES_DIR"] = "gs://gresearch/learned_optimization/opt_archives/"

Using an Optax-Style Optimizer

JAX ships with its own small example optimizer library (jax.example_libraries.optimizers), but it is minimal and does not even implement learning-rate schedules. You can work around that by writing your own function learning_rate_fn(step) and passing it as optimizers.sgd(step_size=learning_rate_fn).
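A minimal sketch of that workaround (the decay schedule below is an arbitrary illustration, not from the original notebook):

from jax.example_libraries import optimizers

def learning_rate_fn(step):
  # Step decay chosen purely for illustration: halve the rate every 300 steps.
  return 0.1 * (0.5 ** (step // 300))

# step_size accepts either a float or a function of the step index.
opt_init, opt_update, get_params = optimizers.sgd(step_size=learning_rate_fn)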

If writing this by hand is too much trouble, you can use the optax library instead (see https://zhuanlan.zhihu.com/p/545561011 for an introduction).
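For comparison, a minimal optax sketch using a built-in schedule (the hyperparameter values here are illustrative only):

import optax

# Cosine decay from 1e-3 down to zero over 1000 updates (arbitrary values).
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1000)
baseline = optax.adam(learning_rate=schedule)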

import optax
# Define a VeLO optimizer that targets 1000 training steps.  VeLO needs the
# total step budget up front because it conditions on training progress.
NUM_STEPS = 1000  # number of steps the optimizer is configured to run for
opt = prefab.optax_lopt(NUM_STEPS)
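The returned opt follows the optax GradientTransformation interface (init / update), so the training loop below would also work with a standard optax optimizer. A sketch of such a swap, assuming a recent optax version that provides with_extra_args_support (needed because our update call passes extra_args):

# Illustrative baseline swap only; the learning rate is chosen arbitrarily.
# with_extra_args_support lets a plain optax optimizer ignore extra_args.
baseline_opt = optax.with_extra_args_support(optax.adam(learning_rate=1e-3))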

Defining and Running a Simple Training Loop

# learned_optimization contains a handful of predefined tasks.  These tasks
# wrap the model initialization and dataset definitions in one convenient
# object, i.e. a predefined task bundling model and data.  Here, we use a
# simple MLP on the FashionMNIST dataset.
task = image_mlp.ImageMLP_FashionMnist8_Relu32()

# We initialize the underlying MLP and collect its state using its init
# function.  Under the hood, this is really just initializing a haiku model
# as seen here (https://github.com/google/learned_optimization/blob/main/learned_optimization/tasks/fixed/image_mlp.py#L58).
# initialize the model parameters
key = jax.random.PRNGKey(0)
params = task.init(key)

# finally, we initialize the optimizer state from the model parameters:
opt_state = opt.init(params)

# For the training loop, all we need is a single update function.  This update
# function takes the existing optimizer state, model params, training data,
# and randomness as args, and returns new optimizer state, new model params,
# and the loss.

@jax.jit
def update(opt_state, params, data, key): 
  """Simple training update function.
  Args:
    opt_state: Optimizer state
    params: Model parameter weights
    data: Training data
    key: Jax randomness
  
  Returns:
    A tuple of updated optimizer state, model state, and the current loss.
    返回一个元组优化器的参数、模型的参数、还有当前的loss训练数据已经用了不需要返回"""
  l, g = jax.value_and_grad(task.loss)(params, key, data)

  # The learned optimizer also consumes the current loss as an extra input
  # feature, which is why it is passed through extra_args.
  updates, opt_state = opt.update(g, opt_state, params=params, extra_args={"loss": l})
  params = optax.apply_updates(params, updates)  # apply the updates to the model parameters
  return opt_state, params, l


# a simple training loop

losses = []
for i in range(NUM_STEPS):
  batch = next(task.datasets.train)  # draw a batch from the training set
  key1, key = jax.random.split(key)  # split the PRNG key so each step gets fresh randomness
  opt_state, params, l = update(opt_state, params, batch, key1)  # run one update step
  losses.append(l) 
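After training, we can sanity-check the learned parameters on held-out data. A hedged sketch, assuming the predefined tasks expose a test split alongside train (as the image tasks in learned_optimization do):

# Evaluate the trained parameters on one held-out batch.
test_batch = next(task.datasets.test)
key, test_key = jax.random.split(key)
test_loss = task.loss(params, test_key, test_batch)
print("final test-batch loss:", test_loss)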

Plotting the Loss

# here we visualize the loss during training
plt.plot(losses)

(Figure: training loss curve over the 1000 training steps.)
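The raw per-step losses are noisy; a short smoothing sketch (window size chosen arbitrarily) can make the trend easier to read:

# Smooth the per-step losses with a simple moving average before plotting.
window = 20  # arbitrary smoothing window
smoothed = onp.convolve(onp.asarray(losses), onp.ones(window) / window, mode="valid")
plt.plot(smoothed)
plt.xlabel("training step")
plt.ylabel("loss (moving average)")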
