First commit
[jiton:jiton.git] / tests / virtual_test.ml
1 (*
2         Attempt at a simple rgb darken function :
3         - read pixel from pixels
4         - unpack565 into r, g, b
5         - read alpha into a
6         - mul r, a into r'
7         - mul g, a into g'
8         - mul b, a into b'
9         - pack565 r', g', b' into pixel'
10         - write pixel' into image
11 *)
12
13 open Jiton
14 open Impl_virtual
15 open Virtual
16 open Big_int
17
18 type symbol = {
19         outputer : int (*index in plan*) * int (*index in inputs*) ;
20         mutable inputers : (int (*index in plan*) * int (*index in inputs*)) list ;
21         mutable birth : int ;   (* first step that needs this var. *)
22         mutable death : int ;   (* last step that needs it. >=birth. If = we will not use the value but we need a reg to store a value nonetheless. *)
23         mutable alloc_bank : bank_num ;
24         mutable alloc_reg : reg_num }
25
26 type plan = {
27         impl : op_impl ;
28         perm_regs : int array ;
29         mutable in_regs : int array ;
30         scratch_regs : int array ;
31         out_regs : int array }
32
33 type prealloc_plan = {
34         prealloc_impl : op_impl ;
35         input_names : string array ;
36         output_names : string array }
37
38 (* For this to work, the first step of the plan must be the entry point, taking
39  * no input and outputing all the function parameters as var_regs. *)
40 let reg_allocation preplan loops =
41         (* Build symbol table. *)
42         Printf.printf "Build symbol table.\n" ;
43         let symbols = Hashtbl.create 10 in
44         Array.iteri (fun plan_idx pp ->
45                 let add_inputer input_idx sym_name =
46                         (* sym_name may be an integer constant *)
47                         (* FIXME: Ocaml's integers are not great for constants, word would be better. *)
48                         try ignore (int_of_string sym_name)
49                         with Failure _ ->
50                                 let s = Hashtbl.find symbols sym_name in
51                                 s.inputers <- (plan_idx, input_idx) :: s.inputers ;
52                                 s.death <- plan_idx in
53                 let create_symbol output_idx sym_name =
54                         assert (Hashtbl.find_all symbols sym_name = []) ;
55                         Hashtbl.add symbols sym_name {
56                                 outputer = plan_idx, output_idx ;
57                                 inputers = [] ;
58                                 birth = plan_idx ;
59                                 death = plan_idx ;
60                                 alloc_bank = preplan.(plan_idx).prealloc_impl.out_banks.(output_idx) ;
61                                 alloc_reg = 0 } in
62                 (* For each input names, add me as an inputer. *)
63                 Array.iteri add_inputer pp.input_names ;
64                 (* For each output names, create the symbol. *)
65                 Array.iteri create_symbol pp.output_names) preplan ;
66         
67         (* But we have loops : all vars outputed before the looping
68          * point and inputed after it must be kept untill loop end. *)
69         let check_death sym_name symbol = 
70                 let (outputer_idx, _) = symbol.outputer in
71                 let check_loop (loop_start_idx, loop_end_idx) =
72                         if
73                                 outputer_idx < loop_start_idx &&
74                                 List.exists (fun (inputer_idx, _) -> inputer_idx >= loop_start_idx) symbol.inputers
75                         then (
76                                 Printf.printf "Make register %s immortal in loop [%d -> %d].\n"
77                                         sym_name loop_start_idx loop_end_idx ;
78                                 symbol.death <- loop_end_idx) in
79                 List.iter check_loop loops in
80         Hashtbl.iter check_death symbols ;
81         
82         (* Find out how many registers we need to store vars. *)
83         let nb_var_regs = Array.init (Array.length register_sets) (fun _ ->
84                 Array.make (Array.length preplan) 0) in
85         Hashtbl.iter (fun _sym_name symbol ->
86                 for i = symbol.birth to symbol.death do
87                         nb_var_regs.(symbol.alloc_bank).(i) <- nb_var_regs.(symbol.alloc_bank).(i) + 1
88                 done) symbols ;
89         let nb_vars = Array.init (Array.length register_sets) (fun s ->
90                 Array.fold_left max 0 nb_var_regs.(s)) in
91         Printf.printf "For vars we need :\n" ;
92         Array.iteri (fun bank nb -> Printf.printf "\t%d regs from bank %d\n" nb bank) nb_vars ;
93         
94         (* Find out how many registers we need for perm and scratch regs, per bank. *)
95         let arr_or_zero arr idx =
96                 if idx < Array.length arr then arr.(idx) else 0 in
97         let get_required_perms bank =
98                 Array.fold_left (fun prev pp -> prev + (arr_or_zero pp.prealloc_impl.perm bank)) 0 preplan in
99         let get_required_scratchs bank =
100                 Array.fold_left (fun prev pp -> max prev (arr_or_zero pp.prealloc_impl.scratch bank)) 0 preplan in
101         let nb_banks   = Array.length register_sets in
102         let nb_perms   = Array.init nb_banks get_required_perms in
103         let nb_scratch = Array.init nb_banks get_required_scratchs in
104         for bank = 0 to nb_banks - 1 do
105                 Printf.printf "Bank %d : need %d perms registers, %d scratch registers.\n"
106                         bank nb_perms.(bank) nb_scratch.(bank)
107         done ;
108
109         (* Partition register banks like this : first the permanent registers,
110          * then the scratch registers, then the var registers. *)
111         let first_perm _bank = 0 in
112         let first_scratch bank = (first_perm bank) + nb_perms.(bank) in
113         let first_var bank = (first_scratch bank) + nb_scratch.(bank) in
114
115         (* For allocating var registers we need a bitmap of these registers : *)
116         let var_reg_bitmap = Array.init (Array.length register_sets) (fun s ->
117                 Array.make nb_vars.(s) false) in
118         let alloc_var_reg bank =
119                 let rec aux i =
120                         if var_reg_bitmap.(bank).(i) then aux (i+1)
121                         else (var_reg_bitmap.(bank).(i) <- true ; i) in
122                 aux 0 in
123         let free_var_reg (bank, i) =
124                 var_reg_bitmap.(bank).(i) <- false in
125         let freelist_at_step = Array.make ((Array.length preplan)+1) [] in
126
127         (* Then build the plan *)
128         let next_perm = Array.init nb_banks (fun b -> first_perm b) in
129         let nb_regs reqs = Array.fold_left (+) 0 reqs in
130         let bank_of reqs i =    (* from [|5;4|] and 7, return 1,2 *)
131                 let rec aux bank i =
132                         if i <= reqs.(bank) then bank, i
133                         else aux (bank+1) (i-reqs.(bank)) in
134                 aux 0 i in
135         let alloc_perm_of_bank bank =
136                 let res = next_perm.(bank) in
137                 next_perm.(bank) <- succ next_perm.(bank) ;
138                 res in
139         let init_plan plan_idx =
140                 (* Free the vars that will not be used any more. *)
141                 List.iter free_var_reg freelist_at_step.(plan_idx) ;
142                 (* Build plan *)
143                 let pp = preplan.(plan_idx) in
144                 let impl = pp.prealloc_impl in {
145                         impl = impl ;
146                         perm_regs = Array.init (nb_regs impl.perm)
147                                 (fun i -> alloc_perm_of_bank (fst (bank_of impl.perm i))) ;
148                         scratch_regs = Array.init (nb_regs impl.scratch)
149                                 (fun i ->
150                                         let bank, reg_in_bank = bank_of impl.scratch i in
151                                         (first_scratch bank) + reg_in_bank) ;
152                         out_regs = Array.init (Array.length pp.output_names) (fun i ->
153                                 let sym_name = pp.output_names.(i) in
154                                 let symbol = Hashtbl.find symbols sym_name in
155                                 let free_slot = alloc_var_reg symbol.alloc_bank in
156                                 let bank_num = symbol.alloc_bank in
157                                 let reg_num = free_slot + (first_var bank_num) in
158                                 freelist_at_step.(symbol.death + 1) <-
159                                         (bank_num, free_slot) :: freelist_at_step.(symbol.death + 1) ;
160                                 Printf.printf "Using register %d.%d for %s up to step %d\n" bank_num reg_num sym_name symbol.death ;
161                                 reg_num) ;
162                         (* We cannot initialize in_regs for now, since we need to refer back to
163                          * previously created plan entries. So we create an dummy array here and will finish
164                          * initialization hereafter. *)
165                         in_regs = [||] } in
166         let plan = Array.init (Array.length preplan) init_plan in
167         (* Finish init of in_regs. *)
168         for plan_idx = 0 to (Array.length plan) - 1 do
169                 let pp = preplan.(plan_idx) in
170                 plan.(plan_idx).in_regs <- Array.init (Array.length pp.input_names) (fun i ->
171                         let sym_name = pp.input_names.(i) in
172                         try int_of_string sym_name
173                         with Failure _ ->
174                                 let symbol = Hashtbl.find symbols sym_name in
175                                 let outputer_idx, output_idx = symbol.outputer in
176                                 plan.(outputer_idx).out_regs.(output_idx))
177         done ;
178         plan
179
180 (* From a "user program" giving only the main ingredients, build a preplan
181  * by cooking possible implementations, served with loop and func machinery.
182  *)
183 type user_sym = string * data_type
184 type user_op = impl_lookup * string array * user_sym array
185 type user_plan = user_op array
186 type user_symbol = { mutable bank : int ; size : data_type }
187
188 let string_of_spec_in specs =
189         let string_of_spec = function
190                 | Reg (bank, size) -> Printf.sprintf "Reg (bank=%d, size=%d)" bank size
191                 | Cst size -> Printf.sprintf "Cst size=%d" size
192                 | Auto size -> Printf.sprintf "Auto size=%d" size in
193         (Array.fold_left (fun prefix spec -> prefix^(string_of_spec spec)^"; ") "[ " specs)^" ] "
194 let string_of_spec_out specs =
195         let string_of_spec spec = Printf.sprintf "size=%d" spec in
196         (Array.fold_left (fun prefix spec -> prefix^(string_of_spec spec)^"; ") "[ " specs)^" ] "
197
198 let make_preplan user_plan func_params =
199         let func_param_names = Array.map fst func_params in
200         let func_param_sizes = Array.map snd func_params in
201         (* Returns a path of larger allowed scale.
202          * A "path" is a prealloc_plan for the loop body (ie user_plan). *)
203         let make_path min_scale max_scale =
204                 let find_path scale =
205                         Printf.printf "Looking for path of scale %d.\n" scale ;
206                         (* Build a symbol table giving the expected size and register bank of symbols. *)
207                         let symbols = Hashtbl.create 10 in
208                         (* Symbols are known from function parameters and later output variables. *)
209                         let add_symbol name size bank =
210                                 Printf.printf "Add symbol %s, bank=%d, size=%d.\n" name bank size ;
211                                 Hashtbl.add symbols name { bank = bank ; size = size } in
212                         let add_symbols bank = Array.iter (fun (name, size) -> add_symbol name size bank) in
213                         add_symbols 0 func_params ;
214                         Array.map (fun (impl_lookup, inputs, outputs) ->
215                                 (* Build input and output specifier. *)
216                                 let specs_in = Array.map (fun sym_name ->
217                                         (* Is it a constant ? *)
218                                         try Cst (int_of_string sym_name)
219                                         with Failure _ -> (     (* Or get info from the symbol table. *)
220                                                 let symbol = Hashtbl.find symbols sym_name in
221                                                 Reg (symbol.bank, symbol.size))) inputs in
222                                 let specs_out = Array.map (fun (_, sym_size) -> sym_size) outputs in
223                                 (* If we can't find it, we could still achieve the same result by repeating scale
224                                  * times the implementation for scale=1. *)
225                                 Printf.printf "\tLooking for an impl @scale=%d, for specs = %s -> %s.\n" scale (string_of_spec_in specs_in) (string_of_spec_out specs_out) ;
226                                 let impl = impl_lookup (scale, specs_in, specs_out) in
227                                 Printf.printf "\tfound an impl giving %d outputs.\n" (Array.length impl.out_banks) ;
228                                 (* Add new symbols for outputs. *)
229                                 Array.iteri (fun i (name, size) ->
230                                         add_symbol name size impl.out_banks.(i)) outputs ;
231                                 (* Returns the prealloc_plan. *)
232                                 {       prealloc_impl = impl ;
233                                         input_names = inputs ;
234                                         output_names = Array.map fst outputs }) user_plan in
235                 let rec aux scale =
236                         if scale < min_scale then raise Not_found
237                         else try (find_path scale), scale with Not_found -> (
238                                 Printf.printf "No path for scale %d.\n" scale ;
239                                 aux (scale - 1)) in
240                 aux max_scale in
241         let preplan_of_path path =
242                 (* Merely add entry/exit points and loop head/tail. *)
243                 let preplan = Array.concat [
244                         [|      {       prealloc_impl = load_params (1, [||], func_param_sizes) ;
245                                         input_names = [||] ; output_names = func_param_names } ;
246                                 {       prealloc_impl = loop_head (1, [| Reg (1, 32) |], [||]) ;
247                                         input_names = [| "width" |] ; output_names = [||] } |] ;
248                         path ;
249                         [|      {       prealloc_impl = loop_tail (1, [||], [||]) ;
250                                         input_names = [||] ; output_names = [||] } |] ] in
251                 let loops = [ 1, (Array.length preplan)-1 ] in
252                 preplan, loops in
253         let path_with_renamed_vars suffix path =
254                 let array_exits arr e =
255                         try (
256                                 Array.iter (fun e' -> if e = e' then raise Exit) arr ;
257                                 false
258                         ) with Exit -> true in
259                 let is_constant s = try (ignore (int_of_string s) ; true) with Failure _ -> false in
260                 let renamed arr = Array.map (fun s ->
261                         if array_exits func_param_names s || is_constant s then s else s^suffix) arr in
262                 Array.map (fun p ->
263                         {       p with
264                                 input_names  = (renamed p.input_names) ;
265                                 output_names = (renamed p.output_names) }) path in
266         let combine_paths slow_path fast_path scale =
267                 (* FIXME: we do not take into account data alignment here. but how can we ? *)
268                 let preplan = Array.concat [
269                         [|      {       prealloc_impl = load_params (scale, [||], func_param_sizes) ;
270                                         input_names = [||] ; output_names = func_param_names } ;
271                                 {       prealloc_impl = loop_head (scale, [| Reg (1, 32) |], [||]) ;
272                                         input_names = [| "width" |] ; output_names = [||] } |] ;
273                         path_with_renamed_vars "[fast]" fast_path ;
274                         [|      {       prealloc_impl = loop_tail (scale, [||], [||]) ;
275                                         input_names = [||] ; output_names = [||] } ;
276                                 {       prealloc_impl = loop_head (1, [| Reg (1, 32) |], [||]) ;
277                                         input_names = [| "width" |] ; output_names = [||] } |] ;
278                         path_with_renamed_vars "[slow_finish]" slow_path ;
279                         [|      {       prealloc_impl = loop_tail (1, [||], [||]) ;
280                                         input_names = [||] ; output_names = [||] } |] ] in
281                 let loop_tail_1 = Array.length fast_path + 2 in
282                 let loops = [ 1, loop_tail_1 ; loop_tail_1 + 1, (Array.length preplan)-1 ] in
283                 preplan, loops in
284         (* First, build the slow path (1 item at a time) *)
285         let slow_path, _ = make_path 1 1 in
286         (* Then look for a fast path (N items at a time) *)
287         try (
288                 let fast_path, scale = make_path 2 8 in
289                 (* Then build the final path using both slow and fast paths. *)
290                 combine_paths slow_path fast_path scale
291         ) with Not_found -> (
292                 (* Or using only the slow path if no fast path was found. *)
293                 preplan_of_path slow_path
294         )
295
296 let () =
297         let proc = make_proc 3 in
298         (* User program. *)
299         let user_program = [|
300                 stream_read, [| "pixels" |], [| "pixel", 16 |] ;
301                 unpack565, [| "pixel" |], [| "r", 8 ; "g", 8 ; "b", 8 |] ;
302                 stream_read, [| "alpha" |], [| "a", 8 |] ;
303                 mul_rshift, [| "r" ; "a" ; "8" |], [| "r'", 8 |] ;
304                 mul_rshift, [| "g" ; "a" ; "8" |], [| "g'", 8 |] ;
305                 mul_rshift, [| "b" ; "a" ; "8" |], [| "b'", 8 |] ;
306                 pack565, [| "r'" ; "g'" ; "b'" |], [| "pixel'", 16 |] ;
307                 stream_write, [| "image" ; "pixel'" |], [||] |] in
308         let func_params = [| "width", 32 ; "pixels", 32 ; "alpha", 32 ; "image", 32 |] in
309         let preplan, loops = make_preplan user_program func_params in
310         (* Register allocation. *)
311         let plan = reg_allocation preplan loops in
312         (* Function exec plan. *)
313         (* Emit entry point *)
314         emit_entry_point proc [||] ;
315         (* Emit loop invariants *)
316         Array.iter (fun { impl=impl; perm_regs=perm_regs } -> impl.preamble_emitter proc perm_regs) plan ;
317         (* Emit function body *)
318         Array.iter (fun { impl=impl; in_regs=in_regs; scratch_regs=scratch_regs; out_regs=out_regs} ->
319                 impl.emitter proc in_regs scratch_regs out_regs) plan ;
320         (* Emit exit code *)
321         emit_exit proc ;
322         (* Copy pixels and alpha into memory *)
323         let nb_pixels = 2 in
324         let pixels_addr = 0 in
325         let alpha_addr = pixels_addr + nb_pixels*2 in
326         let image_addr = alpha_addr + nb_pixels*1 in
327         memory_write_32 pixels_addr (big_int_of_int 0x12345678) ;
328         assert (eq_big_int (big_int_of_int 0x12345678) (memory_read_32 pixels_addr)) ;
329         memory_write_16 alpha_addr (big_int_of_int 0x8040) ; (* first pixel is /4 and second is /2 *)
330         (* Run *)
331         exec proc (Array.map word_of_int [|nb_pixels ; pixels_addr ; alpha_addr ; image_addr|]) ;
332         (* Show result *)
333         let result = memory_read_32 image_addr in
334         Printf.printf "Image = %x\n" (int_of_string (string_of_big_int result)) ;
335         assert (eq_big_int result (big_int_of_int 0x90a1186))
336