1
ONNX Training Discussion
Wei-Sheng Chin, AI Frameworks, Microsoft
2
Training of Generative Adversarial Networks (GANs)
A GAN's training problem is a min-max game between two players: a generator model G_θ with trainable parameters θ, and a discriminator D_w with trainable parameters w. A common optimization problem used to train GANs (the WGAN-GP objective, cited by 1.4k papers):

min_θ max_w L(θ, w),
L(θ, w) = −E_x̃[D_w(x̃)] + E_x[D_w(x)] − λ E_x̂[(‖∇_x̂ D_w(x̂)‖₂ − 1)²],

where x̃ = G_θ(z), z is sampled from a pre-defined distribution p(z), and x̂ = ε x + (1 − ε) x̃.

Important ingredients of the objective function: it is not a simple minimization problem but a min-max problem, and computing ∇_w L requires recursive differentiation such as ∇_w ∇_x̂ D_w.
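To make the last point concrete, here is a minimal sketch in PyTorch (toy linear G and D with arbitrary shapes; it is an illustration, not part of the ONNX proposal) showing that the gradient-penalty term forces a gradient-of-a-gradient computation:

    import torch

    # Toy stand-ins for G_theta and D_w; shapes are arbitrary for illustration.
    G = torch.nn.Linear(8, 4)   # generator: z -> x_tilde
    D = torch.nn.Linear(4, 1)   # discriminator: x -> scalar score
    lam = 10.0

    x = torch.randn(16, 4)      # real samples
    z = torch.randn(16, 8)      # latent samples
    eps = torch.rand(16, 1)
    x_tilde = G(z)
    x_hat = (eps * x + (1 - eps) * x_tilde).detach().requires_grad_(True)

    # First-order gradient w.r.t. x_hat; create_graph=True keeps it differentiable.
    grad_x_hat = torch.autograd.grad(D(x_hat).sum(), x_hat, create_graph=True)[0]
    penalty = ((grad_x_hat.norm(2, dim=1) - 1) ** 2).mean()

    # Discriminator loss: descending it w.r.t. D's parameters implicitly
    # differentiates grad_x_hat D_w(x_hat) again, i.e. a second-order gradient.
    loss_w = D(x_tilde).mean() - D(x).mean() + lam * penalty
    loss_w.backward()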
3
Issues in TrainingInfoProto
It is designed only for minimization problems, so min-max problems are not supported. It only contains one call (as gradient_binding) to compute a gradient, but to train GANs we need more, for example ∇_w ∇_x̂ D_w.
4
Ingredients Needed to Support GANs' Training
A gradient operator which generates gradient tensors: (y, x) → Grad → ∂y/∂x. The gradient tensors can be fed into another gradient operator to compute 2nd-order or higher-order differentiation: (∂y/∂x, x) → Grad → ∂²y/∂x².
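As a rough sketch of how this could be written down with the existing ONNX Python helpers (the "Gradient" op type and its xs/y attributes are the proposal here, not a released operator; tensor names are made up), two chained nodes would express a second-order derivative:

    from onnx import helper

    # First-order: dy/dx, taken from the sub-graph that defines "y" from "x".
    grad1 = helper.make_node(
        "Gradient",              # proposed operator
        inputs=["x"],            # value bound to the differentiated tensor "x"
        outputs=["dy_dx"],
        xs=["x"], y="y",         # proposed attributes: differentiated tensors, target
    )

    # Second-order: feed dy/dx back in as the new target and differentiate again.
    grad2 = helper.make_node(
        "Gradient",
        inputs=["x"],
        outputs=["d2y_dx2"],
        xs=["x"], y="dy_dx",
    )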
5
A Possible Signature of Gradient Operator
Attributes
xs: names of the n differentiated tensors.
y: name of the target tensor.
Inputs
Values of the source tensors. The i-th tensor in "Inputs" is bound to the i-th name in "xs".
Outputs
[∂y/∂x_1, ∂y/∂x_2, …, ∂y/∂x_n].
6
A Possible Signature of Gradient Operator (Cont'd)
Attributes
xs: names of the n differentiated tensors.
y: name of the target tensor.
Attributes are used to identify a sub-graph. To compute ∂L/∂H_2, we need xs = ["H1", "H2", "H3"] and y = "L". The inputs and outputs will be [H1, H2, H3] and [∂L/∂H_1, ∂L/∂H_2, ∂L/∂H_3], respectively. Note that we cannot compute only ∂L/∂H_2, because L depends on H1, H2, and H3.
7
A Possible Signature of Gradient Operator (Cont'd)
Attributes
xs = ["H1", "H2", "H3"].
y = "L".
Procedure to compute the outputs:
1. Copy the identified sub-graph to an isolated space.
2. Bind the indicated inputs to the isolated graph.
3. Conduct a forward pass on the isolated graph.
4. Conduct a backward pass on the isolated graph.
5. Bind the gradient tensors to the output tensors.
6. Deallocate the isolated graph.
8
Another Issue in TrainingInfoProto
To solve the min-max problem in GANs, one training iteration may contain two isolated assignments:
w ← w + ∂L(θ, w)/∂w
θ ← θ − ∂L(θ, w)/∂θ
However, we only have one update_binding in TrainingInfoProto.
9
Possible Changes Related to TrainingInfoProto
Change loss and optimizer to lists of NodeProtos (they are currently single NodeProtos), or merge loss and optimizer into a single list of NodeProtos. Introduce an Update operator, or make ModelProto.training_info a list of TrainingInfoProtos, to capture assignments in different training stages.
10
An Overall Proposal: Introduce Gradient Operator
Attributes
xs: names of the n differentiated tensors.
y: name of the target tensor.
Inputs
Values of the source tensors. The i-th tensor in "Inputs" is bound to the i-th name in "xs".
Outputs
[∂y/∂x_1, ∂y/∂x_2, …, ∂y/∂x_n]. All outputs are optional; the user can use an empty string to indicate unnecessary ones.
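For example, a node that only needs ∂y/∂x_2 could leave the other output slots as empty strings (still using the generic node builder; everything specific to "Gradient" is the proposal, and the tensor names are made up):

    from onnx import helper

    # Only dy/dx2 is materialized; empty strings mark the unused optional outputs.
    grad = helper.make_node(
        "Gradient",
        inputs=["x1", "x2", "x3"],
        outputs=["", "dy_dx2", ""],
        xs=["x1", "x2", "x3"], y="y",
    )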
11
An Overall Proposal (Cont'd)
Merge the loss and optimizer NodeProtos into a function.

message TrainingInfoProto {
  repeated TensorProto initializer = 1;
  optional FunctionProto algorithm = 2;  // Loss, optimizer, etc.
  repeated StringStringEntryProto update_binding = 3;
}
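An update_binding entry is just a name pair, expressible with the standard protobuf classes; the (new value, state tensor) ordering below follows the (w_new, w) convention used later in these slides and is an assumption, not a fixed spec:

    import onnx

    # Pair the freshly computed "w_new" with the stored parameter "w".
    binding = onnx.StringStringEntryProto(key="w_new", value="w")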
12
An Overall Proposal (Cont'd)
The training algorithm can have multiple stages.

message ModelProto {
  ...
  repeated FunctionProto function = 9;
  repeated TrainingInfoProto training_info = 21;
}

The training graph and the inference graph are almost isolated; inside TrainingInfoProto, we can only access initializers of the inference graph. Conducting one recommended training iteration means sequentially running training_info[0], training_info[1], …, training_info[-1] once. (TrainingInfoProto could then be renamed TrainingStageProto.)
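A minimal sketch of the intended execution semantics, using a hypothetical run_stage evaluator and plain Python containers (none of this is an existing ONNX API):

    # One "recommended iteration": run every stage once, in order, and write the
    # update_binding results back into the shared initializer state.
    def run_iteration(training_info, state, run_stage):
        """state: dict mapping initializer names to tensors.
        run_stage: user-supplied evaluator returning a dict of stage outputs."""
        for stage in training_info:          # training_info[0], [1], ..., [-1]
            outputs = run_stage(stage.algorithm, state)
            for new_name, target_name in stage.update_binding:
                state[target_name] = outputs[new_name]
        return state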
13
Store a GAN with a Training Algorithm
ModelProto.function (type: a list of FunctionProtos). Functions can be called in ModelProto.graph and ModelProto.training_info. There are two functions: one for the generator (denoted by G_θ) and the other for the discriminator (denoted by D_w). The parameter θ is stored in ModelProto.graph.initializer; the parameter w is stored in ModelProto.training_info.initializer, so w is training-only. The backward pass is similar to the forward pass; for example, gradient computation can happen in the forward pass. To avoid affecting the inference graph, add a Call-to-Graph mechanism.
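A sketch of declaring one of the two functions as a FunctionProto; the one-node MatMul body, the tensor names, and the "local.functions" domain are illustrative assumptions:

    from onnx import helper

    # Generator function G(z, theta); a real body would contain many more nodes.
    g_body = [helper.make_node("MatMul", ["z", "theta"], ["x_tilde"])]

    G = helper.make_function(
        domain="local.functions",
        fname="G",
        inputs=["z", "theta"],
        outputs=["x_tilde"],
        nodes=g_body,
        opset_imports=[helper.make_opsetid("", 13)],
    )
    # A discriminator function D(x, w) would be declared the same way, and both
    # would be appended to ModelProto.function.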
14
Store a GAN with a Training Algorithm (Cont'd)
ModelProto.graph
ModelProto.graph.initializer holds the initializers used in G_θ.
ModelProto.graph.input holds the inputs of G_θ.
The graph only includes one NodeProto, which is a call to G_θ.
ModelProto.graph.output produces the outputs of G_θ.
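A sketch of that single-node inference graph with the ONNX helpers, reusing the hypothetical "G" function above (shapes and the placeholder initializer are assumptions):

    from onnx import helper, TensorProto

    # The inference graph holds theta and contains one node: a call to "G".
    call_g = helper.make_node("G", ["z", "theta"], ["x_tilde"],
                              domain="local.functions")

    z = helper.make_tensor_value_info("z", TensorProto.FLOAT, [None, 8])
    x_tilde = helper.make_tensor_value_info("x_tilde", TensorProto.FLOAT, [None, 4])
    theta = helper.make_tensor("theta", TensorProto.FLOAT, [8, 4], [0.0] * 32)

    graph = helper.make_graph([call_g], "generator_only",
                              inputs=[z], outputs=[x_tilde], initializer=[theta])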
15
Store a GAN with a Training Algorithm (Cont'd)
ModelProto.training_info[0]
Training algorithm to update D_w's initializers (aka w). Given some data points x, a Random function, and a constant 0 < ε < 1:

z ← Random()
x̃ ← Call<G>(z, θ)   // G is a FunctionProto; θ is stored in ModelProto.graph.initializer.
x̂ ← ε x + (1 − ε) x̃
L1 ← Call<D>(x̃, w)   // D is a FunctionProto; w is stored in ModelProto.graph.initializer.
L2 ← Call<D>(x, w)
L3 ← Call<D>(x̂, w)
[G_x̂, ?] ← Grad(x̂, w, xs=["x̂", "w"], y="L3")   // "?" means the output is not generated.
L ← L1 − L2 + λ (‖G_x̂‖₂ − 1)²
[?, ?, ?, G_w, ?] ← Grad(inputs=[x, x̃, x̂, w, θ], xs=["x", "x̃", "x̂", "w", "θ"], y="L")
w_new ← w − G_w

The pair (w_new, w) appears in update_binding. Both θ and w are stored in the inference graph's initializer list (a global memory shared by inference and training) because they are used in multiple training stages.
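A sketch of how this stage's algorithm could be spelled out as ONNX nodes (ASCII names x_tilde/x_hat stand in for x̃/x̂; "eps", "one", and "lambda" would be constants; the "Gradient" op and the function-call nodes are the proposal, not released operators):

    from onnx import helper

    nodes = [
        helper.make_node("RandomNormal", [], ["z"], shape=[16, 8]),
        helper.make_node("G", ["z", "theta"], ["x_tilde"], domain="local.functions"),
        # x_hat = eps * x + (1 - eps) * x_tilde
        helper.make_node("Mul", ["eps", "x"], ["eps_x"]),
        helper.make_node("Sub", ["one", "eps"], ["one_minus_eps"]),
        helper.make_node("Mul", ["one_minus_eps", "x_tilde"], ["t"]),
        helper.make_node("Add", ["eps_x", "t"], ["x_hat"]),
        helper.make_node("D", ["x_tilde", "w"], ["L1"], domain="local.functions"),
        helper.make_node("D", ["x", "w"], ["L2"], domain="local.functions"),
        helper.make_node("D", ["x_hat", "w"], ["L3"], domain="local.functions"),
        # Only dL3/dx_hat is needed, so the output slot for w stays empty.
        helper.make_node("Gradient", ["x_hat", "w"], ["G_x_hat", ""],
                         xs=["x_hat", "w"], y="L3"),
        # ... nodes computing L = L1 - L2 + lambda * (||G_x_hat||_2 - 1)^2 ...
        helper.make_node("Gradient", ["x", "x_tilde", "x_hat", "w", "theta"],
                         ["", "", "", "G_w", ""],
                         xs=["x", "x_tilde", "x_hat", "w", "theta"], y="L"),
        helper.make_node("Sub", ["w", "G_w"], ["w_new"]),  # (w_new, w) goes in update_binding
    ]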
16
Store a GAN with a Training Algorithm (Cont'd)
ModelProto.training_info[1]
Training algorithm to update G_θ's initializers (aka θ). Given some data points x, a Random function, and a constant 0 < ε < 1:

z ← Random()
x̃ ← Call<G>(z, θ)   // G is a FunctionProto; θ is stored in ModelProto.graph.initializer.
L ← Call<D>(x̃, w)   // D is a FunctionProto; w is stored in ModelProto.graph.initializer.
[?, G_θ, ?] ← Grad(z, xs=["z", "θ", "w"], y="L")
θ_new ← θ + G_θ

The pair (θ_new, θ) appears in update_binding.
17
Thank you!