Module: TensorStream::Train::LearningRateDecay
Included in:
- TensorStream::Trainer
Defined in:
- lib/tensor_stream/train/learning_rate_decay.rb
Constant Summary
Constants included from Ops
Ops::FLOATING_POINT_TYPES, Ops::INTEGER_TYPES, Ops::NUMERIC_TYPES
Instance Method Summary
- #exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase: false, name: nil) ⇒ Object
  Applies exponential decay to the learning rate.
Methods included from Ops
#abs, #acos, #add_n, #asin, #assert_equal, #atan, #broadcast_gradient_args, #case, #cast, #cast_axis, #check_numerics, #clip_by_norm, #concat, #cond, #constant_initializer, #cumprod, #dynamic_partition, #exp, #expand_dims, #eye, #floor_div, #gather, #glorot_uniform_initializer, #gradients, #identity, #index, #invert_permutation, #log, #log1p, #logical_and, #maximum, #minimum, #multiply, #negative, #not_equal, #ones, #ones_initializer, #ones_like, #pack, #pad, #print, #random_normal, #random_uniform_initializer, #reciprocal, #reduce, #reduce_mean, #reshape, #sec, #setdiff1d, #shape_n, #slice, #split, #sqrt, #square, #squared_difference, #squeeze, #stack, #stop_gradient, #transpose, #truncated_normal, #unpack, #unstack, #where, #zeros_initializer, #zeros_like
Methods included from OpStub
#add, #argmax, #argmin, #ceil, #cos, #div, #equal, #expand_dims, #fill, #floor, #floor_div, #greater, #greater_equal, #less, #less_equal, #log, #mat_mul, #max, #min, #mod, #mul, #negate, #not_equal, #ones_like, #pow, #prod, #random_uniform, #range, #rank, #reshape, #round, #rsqrt, #shape, #sigmoid, #sign, #sin, #size, #strided_slice, #sub, #sum, #tan, #tanh, #tile, #top_k, #zeros
Methods included from OpHelper
#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal
Methods included from Utils
#__v_scope_name, #apply_data_type_coercion, #assign, #check_allowed_types, #check_data_types, #check_if_dense, #colocate_with, #constant, #control_dependencies, #convert_to_tensor, #device, #disable_eager_execution, #dynamic_stitch, #enable_eager_execution, #executing_eagerly?, #float32, #get_collection, #get_default_graph, #get_variable, #get_variable_scope, #global_variables_initializer, #graph, #group, #image, #layers, #list_local_devices, #math, #name_scope, #placeholder, #program, #reset_default_graph, #session, #set_random_seed, #train, #trainable_variables, #variable, #variable_scope
Instance Method Details
#exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase: false, name: nil) ⇒ Object
Applies exponential decay to the learning rate. A global_step value is required; the method raises TensorStream::ValueError when it is nil.
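Concretely, as can be read off the source below, the returned rate follows the schedule

  decayed_learning_rate = learning_rate * decay_rate ** (global_step / decay_steps)

where the exponent global_step / decay_steps is floored when staircase: true, so the rate drops in discrete intervals rather than decaying continuously.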
# File 'lib/tensor_stream/train/learning_rate_decay.rb', line 12

def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase: false, name: nil)
  raise TensorStream::ValueError, "global_step is required for exponential_decay." if global_step.nil?

  name_scope(name, default: "ExponentialDecay", values: [learning_rate, global_step, decay_steps, decay_rate]) do
    learning_rate = convert_to_tensor(learning_rate, name: "learning_rate")
    data_type = learning_rate.data_type
    decay_steps = cast(decay_steps, data_type)
    decay_rate = cast(decay_rate, data_type)
    global_step_recomp = cast(global_step, data_type)

    p = global_step_recomp / decay_steps
    p = floor(p) if staircase

    multiply(learning_rate, pow(decay_rate, p), name: name)
  end
end
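For context, a minimal usage sketch. The schedule values (0.1 starting rate, decay by 0.96 every 100,000 steps) are illustrative, and the sketch assumes the top-level helpers listed under Utils above (#variable, #train, #session, #global_variables_initializer) behave like their TensorFlow counterparts; it is not the library's canonical example.

  require "tensor_stream"

  ts = TensorStream

  # Non-trainable counter; the training loop is expected to increment it
  # (optimizers typically do this via a global_step argument to minimize).
  global_step = ts.variable(0, trainable: false, name: "global_step")

  # Start at 0.1 and decay by a factor of 0.96 every 100,000 steps,
  # in discrete jumps because staircase: true.
  learning_rate = ts.train.exponential_decay(0.1, global_step, 100_000, 0.96, staircase: true)

  sess = ts.session
  sess.run(ts.global_variables_initializer)
  puts sess.run(learning_rate) # 0.1 while global_step is below 100_000

Because exponential_decay returns a tensor rather than a plain number, the decayed rate can be fed directly to an optimizer and will be re-evaluated as global_step advances.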