ONNX Training Discussion Wei-Sheng Chin AI Frameworks Microsoft
Training of Generative Adversarial Networks (GANs)
GAN training is a min-max game between two players: a generator model $G_\theta$ with trainable parameters $\theta$, and a discriminator $D_w$ with trainable parameters $w$.
A common optimization problem used to train GANs (cited by 1.4k papers):
$$\min_\theta \max_w L(\theta, w), \qquad L(\theta, w) = -E_{\tilde{x}}[D_w(\tilde{x})] + E_{x}[D_w(x)] - \lambda\, E_{\hat{x}}\big[(\|\nabla_{\hat{x}} D_w(\hat{x})\|_2 - 1)^2\big]$$
where $\tilde{x} = G_\theta(z)$, $z$ is sampled from a pre-defined distribution $p(z)$, and $\hat{x} = \epsilon x + (1-\epsilon)\tilde{x}$.
Important ingredients in the objective function: it is not a simple minimization problem, it is a min-max problem, and to compute $\nabla_w L$, recursive differentiation such as $\nabla_w \nabla_{\hat{x}} D_w$ is needed (a sketch follows below).
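A minimal PyTorch sketch (illustrative only; D, x, x_tilde, eps, and lam are placeholder names) showing where the recursive differentiation comes from: the gradient-penalty term itself contains $\nabla_{\hat{x}} D_w(\hat{x})$, so back-propagating the loss to $w$ has to differentiate through that gradient, which is what create_graph=True enables.

import torch

def gradient_penalty(D, x, x_tilde, eps, lam=10.0):
    # x_hat is the interpolate from the slide; detach() makes it a leaf
    # tensor so that requires_grad_ can be set on it.
    x_hat = (eps * x + (1 - eps) * x_tilde).detach().requires_grad_(True)
    d_out = D(x_hat)
    # First-order gradient w.r.t. x_hat; create_graph=True keeps this
    # gradient differentiable, so a later backward pass reaches w through it.
    grad_xhat, = torch.autograd.grad(d_out.sum(), x_hat, create_graph=True)
    norm = grad_xhat.reshape(x_hat.shape[0], -1).norm(2, dim=1)
    return lam * ((norm - 1) ** 2).mean()

Calling backward() on a loss that includes this penalty is the $\nabla_w \nabla_{\hat{x}} D_w$ computation the slide refers to.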
Issues in TrainingInfoProto
It is designed only for minimization problems, so min-max problems are not supported.
It only contains one call (as gradient_binding) to compute a gradient, but to train GANs we need more, e.g. $\nabla_w \nabla_{\hat{x}} D_w$.
Ingredients Needed to Support GANs' Training
A Gradient operator which generates gradient tensors: $(y, x) \rightarrow \mathrm{Grad} \rightarrow \frac{dy}{dx}$.
The gradient tensors can be fed into another Gradient operator to compute second-order or higher-order derivatives: $(\frac{dy}{dx}, x) \rightarrow \mathrm{Grad} \rightarrow \frac{d^2 y}{d x^2}$ (see the sketch below).
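A possible rendering of this chaining with onnx.helper, as a sketch only: the "Gradient" op type and its xs/y attributes follow the signature proposed on the next slide and are part of this proposal, not a released ONNX operator.

from onnx import helper

# First-order gradient: dy/dx.
first = helper.make_node(
    "Gradient", inputs=["x"], outputs=["dy_dx"],
    xs=["x"], y="y",
)
# Second-order gradient: differentiate dy/dx with respect to x again.
second = helper.make_node(
    "Gradient", inputs=["x"], outputs=["d2y_dx2"],
    xs=["x"], y="dy_dx",
)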
A Possible Signature of the Gradient Operator
Attributes: xs (names of the $n$ differentiated tensors) and y (name of the target tensor).
Inputs: values of the source tensors; the i-th tensor in "Inputs" would be bound to the i-th name in "xs".
Outputs: $[\frac{\partial y}{\partial x_1}, \frac{\partial y}{\partial x_2}, \ldots, \frac{\partial y}{\partial x_n}]$.
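A minimal sketch of one such node built with onnx.helper; again, the "Gradient" op type and the xs/y attributes are the proposed signature, not an existing released operator. The input list is bound positionally to xs as described above.

from onnx import helper

grad_node = helper.make_node(
    "Gradient",                    # proposed op type
    inputs=["x1", "x2"],           # values, bound positionally to xs
    outputs=["dy_dx1", "dy_dx2"],  # [dy/dx1, dy/dx2]
    xs=["x1", "x2"],               # names of the differentiated tensors
    y="y",                         # name of the target tensor
)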
A Possible Signature of the Gradient Operator (Cont'd)
Attributes: xs (names of the $n$ differentiated tensors) and y (name of the target tensor).
The attributes are used to identify a sub-graph. To compute $\frac{\partial L}{\partial H_2}$, we need xs = ["H1", "H2", "H3"] and y = "L"; the inputs and outputs will be $[H_1, H_2, H_3]$ and $[\frac{\partial L}{\partial H_1}, \frac{\partial L}{\partial H_2}, \frac{\partial L}{\partial H_3}]$, respectively.
Note that we cannot compute only $\frac{\partial L}{\partial H_2}$, because $L$ depends on $H_1$, $H_2$, and $H_3$.
A Possible Signature of the Gradient Operator (Cont'd)
Attributes: xs = ["H1", "H2", "H3"], y = "L".
Procedure to compute the outputs (a pseudocode sketch follows below):
1. Copy the identified sub-graph to an isolated space.
2. Bind the indicated inputs to the isolated graph.
3. Conduct a forward pass on the isolated graph.
4. Conduct a backward pass on the isolated graph.
5. Bind the gradient tensors to the output tensors.
6. Deallocate the isolated graph.
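Illustrative pseudocode only for the six steps above; extract_subgraph, run_forward, and run_backward are hypothetical helpers standing in for runtime internals, not ONNX APIs.

def eval_gradient_node(graph, xs, y, input_values,
                       extract_subgraph, run_forward, run_backward):
    sub = extract_subgraph(graph, inputs=xs, output=y)  # 1. copy the identified sub-graph
    feeds = dict(zip(xs, input_values))                 # 2. bind the indicated inputs
    activations = run_forward(sub, feeds)               # 3. forward pass
    grads = run_backward(sub, activations, target=y)    # 4. backward pass
    outputs = [grads[name] for name in xs]              # 5. bind gradients to outputs
    del sub, activations, grads                         # 6. deallocate the isolated graph
    return outputs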
Another Issue in TrainingInfoProto
To solve the min-max problem in GANs, one training iteration may contain two isolated assignments:
$$w \leftarrow w + \frac{\partial L(\theta, w)}{\partial w}, \qquad \theta \leftarrow \theta - \frac{\partial L(\theta, w)}{\partial \theta}$$
However, we only have one "update_binding" in TrainingInfoProto.
Possible Changes Related to TrainingInfoProto
Change loss and optimizer to lists of NodeProto (they are currently single NodeProto's), or merge loss and optimizer into a single list of NodeProto's.
Introduce an Update operator, or make ModelProto.training_info a list of TrainingInfoProto to capture assignments in different training stages.
An Overall Proposal
Introduce a Gradient operator.
Attributes: xs (names of the $n$ differentiated tensors) and y (name of the target tensor).
Inputs: values of the source tensors; the i-th tensor in "Inputs" would be bound to the i-th name in "xs".
Outputs: $[\frac{\partial y}{\partial x_1}, \frac{\partial y}{\partial x_2}, \ldots, \frac{\partial y}{\partial x_n}]$. All outputs are optional; users can use an empty string to indicate unnecessary ones (see the sketch below).
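A sketch of the optional-output convention, reusing the earlier H1/H2/H3 example: only $\frac{\partial L}{\partial H_2}$ is materialized, and the other outputs are left as empty strings. As before, the "Gradient" op type and its attributes are the proposed signature, not a released operator.

from onnx import helper

grad_node = helper.make_node(
    "Gradient",
    inputs=["H1", "H2", "H3"],   # bound positionally to xs
    outputs=["", "dL_dH2", ""],  # empty string marks an unneeded gradient
    xs=["H1", "H2", "H3"],
    y="L",
)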
An Overall Proposal (Cont'd)
Merge the loss and optimizer NodeProto's into a function.

message TrainingInfoProto {
  repeated TensorProto initializer = 1;
  optional FunctionProto algorithm = 2;  // Loss, optimizer, etc.
  repeated StringStringEntryProto update_binding = 3;
}
An Overall Proposal (Cont'd)
A training algorithm can have multiple stages.

message ModelProto {
  ...
  repeated FunctionProto function = 9;
  repeated TrainingInfoProto training_info = 21;
}

The training graph and the inference graph are almost isolated: inside TrainingInfoProto, we can only access initializers in the inference graph.
Conducting one recommended iteration means sequentially running training_info[0], training_info[1], ..., training_info[-1] once (see the sketch below).
Possibly rename TrainingInfoProto to TrainingStageProto.
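A conceptual sketch only of that recommended iteration; run_stage is a hypothetical runtime entry point, not an ONNX API.

def run_one_iteration(model, run_stage, batch):
    # model.training_info is the proposed repeated TrainingInfoProto field.
    for stage in model.training_info:   # training_info[0], training_info[1], ...
        run_stage(stage, batch)         # each stage applies its own update_binding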
Store GAN with a Training Algorithm
ModelProto.function (type: a list of FunctionProto's). Functions can be called in ModelProto.graph and ModelProto.training_info.
There are two functions: one for the generator (denoted by $G_\theta$) and one for the discriminator (denoted by $D_w$).
The parameter $\theta$ is stored in ModelProto.graph.initializer.
The parameter $w$ is stored in ModelProto.training_info.initializer; $w$ is training-only.
The backward pass is similar to the forward pass; for example, gradient computation can happen in the forward pass.
To avoid affecting the inference graph, add a Call-to-Graph.
Store GAN with a Training Algorithm (Cont'd)
ModelProto.graph
ModelProto.graph.initializer holds the initializers used in $G_\theta$.
ModelProto.graph.input holds the inputs of $G_\theta$.
The graph includes only one NodeProto, which is a call to $G_\theta$.
ModelProto.graph.output produces the outputs of $G_\theta$.
Store GAN with a Training Algorithm (Cont'd)
ModelProto.training_info[0]: the training algorithm that updates $D_w$'s initializers (i.e., $w$).
Given some data points $x$, a Random function, and a constant $0 < \epsilon < 1$:
z ← Random()
x̃ ← Call<G>(z, θ)  // G is a FunctionProto and θ is stored in ModelProto.graph.initializer.
x̂ ← εx + (1 - ε)x̃
L1 ← Call<D>(x̃, w)  // D is a FunctionProto and w is stored in ModelProto.graph.initializer.
L2 ← Call<D>(x, w)
L3 ← Call<D>(x̂, w)
[G_x̂, ?] ← Grad(x̂, w, xs=["x̂", "w"], y="L3")  // "?" means that output is not generated.
L ← L1 - L2 + λ(‖G_x̂‖₂ - 1)²
[?, ?, ?, G_w, ?] ← Grad(inputs=[x, x̃, x̂, w, θ], xs=["x", "x̃", "x̂", "w", "θ"], y="L")
w_new ← w - G_w
The pair (w_new, w) appears in update_binding.
Both θ and w are stored in the inference graph's initializer list (a global memory shared by inference and training) because they are used in multiple training stages.
Store GAN with a Training Algorithm (Cont'd)
ModelProto.training_info[1]: the training algorithm that updates $G_\theta$'s initializers (i.e., $\theta$).
Given some data points $x$, a Random function, and a constant $0 < \epsilon < 1$:
z ← Random()
x̃ ← Call<G>(z, θ)  // G is a FunctionProto and θ is stored in ModelProto.graph.initializer.
L ← Call<D>(x̃, w)  // D is a FunctionProto and w is stored in ModelProto.graph.initializer.
[?, G, ?] ← Grad(z, xs=["z", "θ", "w"], y="L")
θ_new ← θ + G
The pair (θ_new, θ) appears in update_binding.
Thank you!