
1 ONNX Training Discussion
Wei-Sheng Chin, AI Frameworks, Microsoft

2 Training of Generative Adversarial Networks (GANs)
GAN training is a min-max game between two players:
A generator model G_θ with trainable parameters θ.
A discriminator D_w with trainable parameters w.
A common optimization problem used to train GANs (cited by 1.4k papers):
min_θ max_w L(θ, w)
L(θ, w) = -E_x[D_w(x)] + E_x̃[D_w(x̃)] - λ E_x̂[(‖∇_x̂ D_w(x̂)‖₂ - 1)²]
where x̃ = G_θ(z), z is sampled from a pre-defined distribution p(z), and x̂ = ε x + (1 - ε) x̃.
Important ingredients in the objective function:
It is not a simple minimization problem; it is a min-max problem.
To compute ∇_w L, recursive differentiation such as ∇_w ∇_x̂ D_w is needed.
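To make the recursive-differentiation requirement concrete, here is a minimal PyTorch-style sketch of the gradient penalty. PyTorch is used only for illustration and is not part of this ONNX discussion; the function and variable names are made up. The point is that the first-order gradient must stay differentiable so that backpropagating the penalty produces terms like ∇_w ∇_x̂ D_w.

import torch

def gradient_penalty(D, x, x_tilde, lam=10.0):
    # x_hat = eps * x + (1 - eps) * x_tilde, with eps sampled per example.
    eps = torch.rand(x.size(0), 1)
    x_hat = (eps * x + (1 - eps) * x_tilde).requires_grad_(True)
    d_hat = D(x_hat)
    # First-order gradient w.r.t. x_hat; create_graph=True keeps this result
    # differentiable, so a later backward pass through the penalty also
    # differentiates the gradient itself (the grad_w grad_x_hat D_w term).
    grads = torch.autograd.grad(d_hat.sum(), x_hat, create_graph=True)[0]
    return lam * ((grads.norm(2, dim=1) - 1.0) ** 2).mean()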

3 Issues in TrainingInfoProto
It is designed only for minimization problems, so min-max problems are not supported.
It contains only one call (as gradient_binding) to compute a gradient, but to train GANs we need more, for example ∇_w ∇_x̂ D_w.

4 Ingredients Needed to Support GANs’ Training
A Gradient operator which generates gradient tensors:
(y, x) → Grad → dy/dx
The gradient tensors can be fed into another Gradient operator to compute second-order or higher-order differentiation, as sketched below:
(dy/dx, x) → Grad → d²y/dx²
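As a rough illustration of the chaining, the following Python sketch builds two such Grad nodes with onnx.helper. The operator name "Gradient", the "ai.onnx.preview.training" domain, and all tensor names are assumptions made for illustration, not a finalized spec.

import onnx.helper as helper

# First-order: differentiate y with respect to x, producing dy_dx.
grad1 = helper.make_node(
    "Gradient", inputs=["x"], outputs=["dy_dx"],
    xs=["x"], y="y", domain="ai.onnx.preview.training")

# Second-order: the target is now the first node's output dy_dx, so
# differentiating it again with respect to x yields d2y_dx2.
grad2 = helper.make_node(
    "Gradient", inputs=["x"], outputs=["d2y_dx2"],
    xs=["x"], y="dy_dx", domain="ai.onnx.preview.training")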

5 A Possible Signature of Gradient Operator
Attributes:
xs: names of the n differentiated tensors.
y: name of the target tensor.
Inputs:
Values of the source tensors. The i-th tensor in "Inputs" is bound to the i-th tensor in "xs".
Outputs:
[∂y/∂x_1, ∂y/∂x_2, ..., ∂y/∂x_n].

6 A Possible Signature of Gradient Operator (Cont’d)
Attributes:
xs: names of the n differentiated tensors.
y: name of the target tensor.
The attributes identify a sub-graph. To compute ∂L/∂H2, we need xs = ["H1", "H2", "H3"] and y = "L".
Inputs and outputs are then [H1, H2, H3] and [∂L/∂H1, ∂L/∂H2, ∂L/∂H3], respectively.
Note that we cannot compute only ∂L/∂H2, because L depends on H1, H2, and H3.
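A hedged onnx.helper sketch of this example; the operator name, the domain, and the output tensor names are illustrative assumptions.

import onnx.helper as helper

# One Gradient node for the sub-graph identified by xs and y.
# Even if only dL/dH2 is wanted, H1, H2, and H3 must all appear in xs,
# because L depends on all three of them.
grad = helper.make_node(
    "Gradient",
    inputs=["H1", "H2", "H3"],              # bound position-wise to xs
    outputs=["dL_dH1", "dL_dH2", "dL_dH3"],
    xs=["H1", "H2", "H3"],
    y="L",
    domain="ai.onnx.preview.training")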

7 A Possible Signature of Gradient Operator (Cont’d)
Attributes: xs = ["H1", "H2", "H3"], y = "L".
Procedure to compute the outputs:
1. Copy the identified sub-graph to an isolated space.
2. Bind the indicated inputs to the isolated graph.
3. Conduct a forward pass on the isolated graph.
4. Conduct a backward pass on the isolated graph.
5. Bind the gradient tensors to the output tensors.
6. Deallocate the isolated graph.
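In a runtime, the six steps could look roughly like the Python pseudocode below; every helper here (extract_subgraph, run_forward, run_backward) is hypothetical and only names the step it performs.

def run_gradient_node(graph, xs, y, input_values):
    sub = extract_subgraph(graph, inputs=xs, outputs=[y])  # 1. copy the identified sub-graph to an isolated space (hypothetical helper)
    env = dict(zip(xs, input_values))                      # 2. bind the indicated inputs
    activations = run_forward(sub, env)                    # 3. forward pass on the isolated graph (hypothetical helper)
    gradients = run_backward(sub, activations, target=y)   # 4. backward pass on the isolated graph (hypothetical helper)
    outputs = [gradients[name] for name in xs]             # 5. bind gradient tensors to the output tensors
    del sub, activations, gradients                        # 6. deallocate the isolated graph
    return outputs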

8 Another Issue in TrainingInfoProto
To solve the min-max problem in GANs, one training iteration may contain two isolated assignments:
w ← w + ∂L(θ, w)/∂w
θ ← θ - ∂L(θ, w)/∂θ
However, we only have one update_binding in the TrainingInfoProto.

9 Possible Changes Related to TrainingInfoProto
Change loss and optimizer to lists of NodeProto (they are currently single NodeProto's), or merge loss and optimizer into a single list of NodeProto's.
Introduce an Update operator, or make ModelProto.training_info a list of TrainingInfoProto to capture the assignments in different training stages.

10 An Overall Proposal: Introduce Gradient Operator
Attributes:
xs: names of the n differentiated tensors.
y: name of the target tensor.
Inputs:
Values of the source tensors. The i-th tensor in "Inputs" is bound to the i-th tensor in "xs".
Outputs:
[∂y/∂x_1, ∂y/∂x_2, ..., ∂y/∂x_n]. All outputs are optional; the user can use an empty string to indicate unnecessary ones, as in the sketch below.
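For example (same assumed operator name and domain as in the earlier sketches), requesting only ∂L/∂H2 could look like this, with empty strings marking the unneeded outputs.

import onnx.helper as helper

grad = helper.make_node(
    "Gradient",
    inputs=["H1", "H2", "H3"],
    outputs=["", "dL_dH2", ""],   # empty names: dL/dH1 and dL/dH3 are not materialized
    xs=["H1", "H2", "H3"],
    y="L",
    domain="ai.onnx.preview.training")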

11 An Overall Proposal (Cont’d)
Merge the loss and optimizer NodeProto's into a function.

message TrainingInfoProto {
  repeated TensorProto initializer = 1;
  optional FunctionProto algorithm = 2;  // Loss, optimizer, etc.
  repeated StringStringEntryProto update_binding = 3;
}

12 An Overall Proposal (Cont’d)
The training algorithm can have multiple stages.

message ModelProto {
  ...
  repeated FunctionProto function = 9;
  repeated TrainingInfoProto training_info = 21;
}

The training graph and the inference graph are almost isolated; inside TrainingInfoProto, we can only access initializers in the inference graph.
Conducting one recommended iteration means sequentially running training_info[0], training_info[1], ..., training_info[-1] once, as in the sketch below.
Possible rename: TrainingInfoProto → TrainingStageProto.
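Under this proposal, a training runner's outer loop might look roughly like the following Python sketch; run_algorithm and apply_update are hypothetical stand-ins, since the slide only fixes the order in which the stages run.

def run_one_training_iteration(model, batch):
    # One recommended iteration: run training_info[0], training_info[1], ...,
    # training_info[-1] once, in that order.
    for stage in model.training_info:
        results = run_algorithm(stage.algorithm, batch)  # hypothetical executor of the stage's function
        for pair in stage.update_binding:
            # Each pair ties a freshly computed tensor to the stored
            # initializer it overwrites, e.g. (w_new, w); apply_update is hypothetical.
            apply_update(model, pair, results)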

13 Store GAN with A Training Algorithm
ModelProto.function (type: a list of FunctionProto's). Functions can be called in ModelProto.graph and ModelProto.training_info.
There are two functions: one for the generator (denoted by G_θ) and one for the discriminator (denoted by D_w).
The parameter θ is stored in ModelProto.graph.initializer.
The parameter w is stored in ModelProto.training_info.initializer; w is training-only.
The backward pass is similar to the forward pass; for example, gradient computation can happen in the forward pass.
To avoid affecting the inference graph, add Call-to-Graph.

14 Store GAN with A Training Algorithm (Cont’d)
ModelProto.graph:
ModelProto.graph.initializer holds the initializers used in G_θ.
ModelProto.graph.input is the list of inputs of G_θ.
The graph includes only one NodeProto, which is a call to G_θ (see the sketch below).
ModelProto.graph.output produces the outputs of G_θ.
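A hedged onnx.helper sketch of such an inference graph; the function/op name "G", the custom domain, and the tensor shapes are illustrative assumptions.

import onnx
import onnx.helper as helper

# Calling the generator FunctionProto is a single node whose op_type is the
# function name in the function's (custom) domain.
call_g = helper.make_node(
    "G", inputs=["z", "theta"], outputs=["x_tilde"], domain="local.functions")
# "theta" is expected to be supplied by ModelProto.graph.initializer.

graph = helper.make_graph(
    [call_g],
    "generator_inference",
    inputs=[helper.make_tensor_value_info("z", onnx.TensorProto.FLOAT, [1, 100])],
    outputs=[helper.make_tensor_value_info("x_tilde", onnx.TensorProto.FLOAT, [1, 784])])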

15 Store GAN with A Training Algorithm (Cont’d)
ModelProto.training_info[0]: the training algorithm that updates D_w's initializers (aka w).
Given some data points x, a Random function, and a constant 0 < ε < 1:
z ← Random()
x̃ ← Call<G>(z, θ)      // G is a FunctionProto and θ is stored in ModelProto.graph.initializer.
x̂ ← ε x + (1 - ε) x̃
L1 ← Call<D>(x̃, w)      // D is a FunctionProto and w is stored in ModelProto.graph.initializer.
L2 ← Call<D>(x, w)
L3 ← Call<D>(x̂, w)
[G_x̂, ?] ← Grad(x̂, w, xs=["x̂", "w"], y="L3")      // "?" means the output is not generated.
L ← L1 - L2 + λ (‖G_x̂‖₂ - 1)²
[?, ?, ?, G_w, ?] ← Grad(inputs=[x, x̃, x̂, w, θ], xs=["x", "x̃", "x̂", "w", "θ"], y="L")
w_new ← w - G_w
The pair (w_new, w) appears in update_binding.
Both θ and w are stored in the inference graph's initializer list (a global memory shared by inference and training) because they are used in multiple training stages.

16 Store GAN with A Training Algorithm (Cont’d)
ModelProto.training_info[1]: the training algorithm that updates G_θ's initializers (aka θ).
Given some data points x, a Random function, and a constant 0 < ε < 1:
z ← Random()
x̃ ← Call<G>(z, θ)      // G is a FunctionProto and θ is stored in ModelProto.graph.initializer.
L ← Call<D>(x̃, w)      // D is a FunctionProto and w is stored in ModelProto.graph.initializer.
[?, G, ?] ← Grad(z, xs=["z", "θ", "w"], y="L")
θ_new ← θ + G
The pair (θ_new, θ) appears in update_binding.

17 Thank you!

