[dicp][ascend] Enable 70B get_qkv stage dynamic shape. #793

Open · wants to merge 3 commits into main
7 changes: 7 additions & 0 deletions dicp/dicp/dynamo_bridge/utils.py
@@ -14,6 +14,13 @@ def symint_in_shape(shape):
return False


def not_all_num_shape(shape):
for elem in shape:
if not isinstance(elem, int):
return True
return False


def save_cpu_gm(gm: torch.fx.GraphModule, folder: str):
Path(folder).mkdir(exist_ok=True)
cpu_gm = copy_gm_to_cpu(gm)
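
The new helper broadens the shape check used during conversion: unlike symint_in_shape, not_all_num_shape flags any element that is not a plain int, so shapes containing fx Proxy nodes are also routed through dynamic-shape handling. A minimal usage sketch (the example shapes are illustrative, not taken from the PR):

# usage sketch, assuming the module layout in this PR
from dicp.dynamo_bridge.utils import not_all_num_shape, symint_in_shape

print(not_all_num_shape([1, 16, 128]))        # False: every element is a plain int
print(not_all_num_shape([1, "s0 + 1", 128]))  # True: any non-int element (SymInt, Proxy, ...) triggers it
# symint_in_shape only detects torch.SymInt elements, which is why
# get_shape_proxy in conversion.py switches to the broader check.
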
79 changes: 25 additions & 54 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -219,58 +219,19 @@ def check_tensor(a, b, atol=5e-2, rtol=1e-2):
)
return self.import_code.getvalue()

def operator_in_str(self, st):
for op in ['+', '-', '*', '/']:
if op in st:
return True
return False

def process_sym_name(self, st):
# dynamic shape feature
if st.isdigit():
return st
elif '+' in st:
sp = st.split('+')
if len(sp) > 2:
sp = [sp[0], '+'.join(sp[1:])]
assert (len(sp) == 2)
sp = [elem.strip() for elem in sp]
if sp[0].isdigit():
(sp[1], sp[0]) = (sp[0], sp[1])
if sp[0] in self.sym_in_args:
arg, idx = self.sym_in_args[sp[0]]
return "{}.shape[{}]".format(arg, idx) + '+' + sp[1]
if sp[0] in self.sym_to_inputs.keys():
return self.sym_to_inputs[sp[0]] + '+' + sp[1]
else:
return self.process_sym_name(sp[0]) + '+' + sp[1]
elif '-' in st:
sp = st.split('-')
if len(sp) > 2:
sp = [sp[0], '-'.join(sp[1:])]
assert (len(sp) == 2)
sp = [elem.strip() for elem in sp]
if sp[0] in self.sym_in_args:
arg, idx = self.sym_in_args[sp[0]]
return "{}.shape[{}]".format(arg, idx) + '-' + sp[1]
if sp[0] in self.sym_to_inputs.keys():
return self.sym_to_inputs[sp[0]] + '-' + sp[1]
else:
return self.process_sym_name(sp[0]) + '-' + sp[1]
elif '*' in st:
sp = st.split('*')
if len(sp) > 2:
sp = [sp[0], '*'.join(sp[1:])]
assert (len(sp) == 2)
sp = [elem.strip() for elem in sp]
if sp[0].isdigit():
(sp[1], sp[0]) = (sp[0], sp[1])
if sp[0] in self.sym_in_args:
arg, idx = self.sym_in_args[sp[0]]
return "{}.shape[{}]".format(arg, idx) + '*' + sp[1]
if sp[0] in self.sym_to_inputs.keys():
return self.sym_to_inputs[sp[0]] + '*' + sp[1]
else:
return self.process_sym_name(sp[0]) + '*' + sp[1]
else:
if st in self.sym_in_args:
arg, idx = self.sym_in_args[st]
return "{}.shape[{}]".format(arg, idx)
return self.sym_to_inputs[st]
# in the new version, simply return the string form
# node.str() will not fall back to the SymInt's concrete value
if isinstance(st, torch.SymInt):
return st.node.str()
return str(st)

def gen_call_func(self):
# TODO check scalar input
@@ -283,6 +244,16 @@ def gen_call_func(self):
args = ['_' if arg not in shape_symint and arg not in self.sym_to_inputs.values() else arg for arg in self.args]
call_body.writeline(f"({','.join(args)}) = args")

# bind each SymInt name to the shape value of its corresponding input arg
if len(self.sym_in_args) > 0:
for key in self.sym_in_args.keys():
if not key.isdigit() and not self.operator_in_str(key):
call_body.writeline(f"{key} = {self.sym_in_args[key][0]}.shape[{self.sym_in_args[key][1]}]")
if len(self.sym_to_inputs) > 0:
for key in self.sym_to_inputs.keys():
if not key.isdigit() and not self.operator_in_str(key):
call_body.writeline(f"{key} = {self.sym_to_inputs[key]}")

# generate input dims
if len(self.dynamic_inputs) > 0:
dim_len = 0
@@ -315,20 +286,20 @@ def gen_call_func(self):
shape = list(elem.shape)
if len(shape) == 0:
raise RuntimeError("Error handling empty output_shape")
shape = [self.process_sym_name(str(dim)) for dim in shape]
shape = [self.process_sym_name(dim) for dim in shape]
shape_str += "[" + ','.join(map(str, shape)) + "],"

# process output_shape with modified args
for elem in self.assign_args:
shape = list(self.input_args[elem[1]].meta['val'].shape)
if len(shape) == 0:
raise RuntimeError("Error handling empty output_shape")
shape = [self.process_sym_name(str(dim)) for dim in shape]
shape = [self.process_sym_name(dim) for dim in shape]
shape_str += "[" + ','.join(map(str, shape)) + "],"
stride = list(self.input_args[elem[1]].meta['val'].stride())
if len(stride) == 0:
raise RuntimeError("Error handling empty output_stride")
stride = [self.process_sym_name(str(dim)) for dim in stride]
stride = [self.process_sym_name(dim) for dim in stride]
extra_stride_str += '[' + ','.join(map(str, stride)) + '],'
extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ','
shape_str = shape_str[:-1] + ''']'''
Expand All @@ -351,7 +322,7 @@ def gen_call_func(self):
out_storage_offsets.append('0')
continue
stride = list(elem.stride())
stride = [self.process_sym_name(str(dim)) for dim in stride]
stride = [self.process_sym_name(dim) for dim in stride]
out_strides.append(str(stride))
out_storage_offsets.append(elem.storage_offset())
call_body.writeline(f'out_stride = {out_strides}')
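
With this change, codegen no longer parses symbolic shape expressions piece by piece: process_sym_name just stringifies the value (via node.str() for torch.SymInt), and gen_call_func binds every SymInt name to a Python variable in the generated call body, so shape expressions evaluate naturally at runtime. A hedged sketch of the kind of call body the generator is expected to emit (the names s0, s1, arg0, arg1 and the shapes are made up for this illustration):

# illustrative call-body lines in the spirit of what gen_call_func emits,
# not a verbatim trace
(arg0, arg1) = args
s0 = arg0.shape[0]     # SymInt bound to an input tensor's shape (sym_in_args)
s1 = arg1              # SymInt passed directly as an input (sym_to_inputs)
shape_of_output = [[s0, s1 + 1, 128]]   # symbolic dims now evaluate as plain Python
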
179 changes: 145 additions & 34 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -14,7 +14,7 @@
from torch.fx.immutable_collections import immutable_list
from torch._subclasses import FakeTensor
import dicp.vendor.AscendGraph.ascend_op as ascend_op
from dicp.dynamo_bridge.utils import symint_in_shape
from dicp.dynamo_bridge.utils import symint_in_shape, not_all_num_shape
from dicp.vendor.AscendGraph.codegen.utils import (
get_ascend_dtype
)
@@ -78,65 +78,170 @@ def generate_digits_op(shapes):
ascend_op.Const, (shapes, torch.int32, [len(shapes)]))
x_names.append(const_op)

def generate_sym_int(elem):
elem = elem.node.str()
elems = elem.strip().split(' ')
def find_root_num(set_num, num):
while set_num[num] != num:
num = set_num[num]
return num

def merge_disjoint_set(set_num, idx_a, idx_b):
root_a = find_root_num(set_num, idx_a)
root_b = find_root_num(set_num, idx_b)
# an example for (s5 / 8) - (s5 / 16)
# num:               0 1 2 3
# step1 -> set_num:  0 1 2 3
# step2 -> set_num:  0 0 2 2
# step3 -> set_num:  0 0 0 0

# return the merged set, with members of root_b re-rooted to root_a
return [root_a if find_root_num(set_num, s) == root_b else s for s in set_num]

def replace_elem_proxy(elem_str):
# exit if already a proxy
if isinstance(elem_str, torch.fx.proxy.Proxy):
return elem_str
assert not elem_str in ['+', '-', '*', '//', '(', ')']

# handle plain integers
if elem_str.isdigit():
const_op = self.get_proxy(
ascend_op.Const, ([int(elem_str)], torch.int32, [1]))
return const_op

arg = None
# dynamic shape feature
if elems[0] in self.sym_in_args:
arg, idx = self.sym_in_args[elems[0]]
# handle elements that come from an InputArg's shape
if elem_str in self.sym_in_args:
arg, idx = self.sym_in_args[elem_str]
shape = self.get_proxy(ascend_op.Shape, (arg,))
axis = self.get_proxy(
ascend_op.Const, ([0], torch.int32, [1]))
indice = self.get_proxy(
ascend_op.Const, ([idx], torch.int32, [1]))
gather = self.get_proxy(
ascend_op.GatherV2, (shape, indice, axis))
return gather

# otherwise the SymInt maps directly to an input arg
return self.sym_to_inputs[elem_str]

def generate_not_num(elem):
# the element is already a node Proxy; append it directly
if isinstance(elem, torch.fx.proxy.Proxy):
x_names.append(elem)
return

elem_str = elem.node.str()
elem_str = elem_str.replace('+', ' + ')
elem_str = elem_str.replace('-', ' - ')
elem_str = elem_str.replace('*', ' * ')
elem_str = elem_str.replace('//', ' // ')
elem_str = elem_str.replace('(', ' ( ')
elem_str = elem_str.replace(')', ' ) ')
elems = elem_str.split(' ')
elems = [e for e in elems if e != '']

# dynamic shape feature
if len(elems) > 1:
assert len(elems) == 3
assert elems[2].isdigit()
assert elems[1] == '+' or elems[1] == '-'
const_op = self.get_proxy(
ascend_op.Const, ([int(elems[2])], torch.int32, [1]))
if arg is not None:
args = (gather, const_op)
else:
args = (self.sym_to_inputs[elems[0]], const_op)
if elems[1] == '+':
x_names.append(self.get_proxy(ascend_op.Add, args))
else:
x_names.append(self.get_proxy(ascend_op.Sub, args))
set_num = []
priority = []
nest = 0

# calculate priority for each operator
# set initial set number
for idx, e in enumerate(elems):
if e == '+' or e == '-':
priority.append(nest * 3 + 0)
elif e == '*' or e == '//':
priority.append(nest * 3 + 1)
else:
if e == '(':
nest += 1
elif e == ')':
nest -= 1
priority.append(-1)

# init set number
if not e in ['+', '-', '*', '//', '(', ')']:
set_num.append(idx)
else:
set_num.append(-1)

# start merging the disjoint sets
if len(set_num) > 1:
while len(set(set_num)) > 2:
# seek the highest priority operator
max = -1
m_idx = -1
for idx, prio in enumerate(priority):
if prio > max:
max = prio
m_idx = idx

# fold the two operands of the highest-priority operator
# find the left & right operands
left_idx = m_idx - 1
while left_idx > 0 and str(elems[left_idx]) in ['(', ')']:
left_idx -= 1
right_idx = m_idx + 1
while right_idx < len(elems) - 1 and str(elems[right_idx]) in ['(', ')']:
right_idx += 1
left_idx = find_root_num(set_num, set_num[left_idx])
right_idx = find_root_num(set_num, set_num[right_idx])
left_elem = replace_elem_proxy(elems[left_idx])
right_elem = replace_elem_proxy(elems[right_idx])

# generate calculation operator
if elems[m_idx] == '+':
elems[left_idx] = self.get_proxy(ascend_op.Add, (left_elem, right_elem))
elif elems[m_idx] == '-':
elems[left_idx] = self.get_proxy(ascend_op.Sub, (left_elem, right_elem))
elif elems[m_idx] == '*':
elems[left_idx] = self.get_proxy(ascend_op.Mul, (left_elem, right_elem))
else:
elems[left_idx] = self.get_proxy(ascend_op.Div, (left_elem, right_elem))

# merge set number and priority
set_num = merge_disjoint_set(set_num, left_idx, right_idx)
priority[m_idx] = -1

# add final element proxy
final_idx = 0
while final_idx < len(elems) - 1 and str(elems[final_idx]) in ['(', ')']:
final_idx += 1
final_elem = replace_elem_proxy(elems[final_idx])
x_names.append(final_elem)
else:
if arg is not None:
x_names.append(gather)
else:
x_names.append(self.sym_to_inputs[elems[0]])
# only one non-numeric element
node = replace_elem_proxy(elems[0])
x_names.append(node)

dims = []
for elem in shape:
if not isinstance(elem, torch.SymInt):
# process number
if isinstance(elem, int):
dims.append(elem)
continue
st = elem.node.str()
st = str(elem)
if st.isdigit():
dims.append(int(st))
continue

# add number block
if len(dims) > 0:
generate_digits_op(dims)
dims = []
generate_sym_int(elem)
generate_not_num(elem)

# add last number block
if len(dims) > 0:
generate_digits_op(dims)

# concat all ops
return self.get_proxy(ascend_op.ConcatD, (x_names, 0))

def get_shape_proxy(self, shape, dtype=torch.int32):
if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor):
return shape
elif isinstance(shape, list) and symint_in_shape(shape):
elif isinstance(shape, list) and not_all_num_shape(shape):
# include both SymInt & NodeProxy
return self.process_dynamic_shape(shape)
else:
return self.get_proxy(
@@ -306,12 +411,16 @@ def inge(self, x, y):
y = self.get_const_proxy(y, torch.int32)
return self.get_proxy(ascend_op.GreaterEqual, (x, y))

@register_conversion(aten.div)
@register_conversion([aten.div, _operator.floordiv])
def div(self, x, y):
if isinstance(y, torch.fx.proxy.Proxy):
return self.get_proxy(ascend_op.DivNoNan, (x, y))
assert y != 0
out_dtype = fx_traceback.get_current_meta()['val'].dtype
out = fx_traceback.get_current_meta()['val']
if not isinstance(out, torch.SymInt):
out_dtype = out.dtype
else:
out_dtype = torch.int32
y_op = self.get_const_proxy(y, out_dtype)
return self.get_proxy(ascend_op.Div, (x, y_op), {})

@@ -332,10 +441,12 @@ def slice(self, x, dim=0, start=None, end=None, step=1):
y_shape = list(fx_traceback.get_current_meta()['val'].shape)
# y_shape = fx_traceback.get_current_meta()['val'].shape
dim = int(dim)
start = int(start) if start is not None else 0
start = start if start >= 0 else x_shape[dim] + start
if not isinstance(start, torch.fx.proxy.Proxy):
start = int(start) if start is not None else 0
start = start if start >= 0 else x_shape[dim] + start
assert start is None or start >= 0 and start < x_shape[dim]

assert dim == -1 or dim >= 0 and dim < len(x_shape)
assert start is None or start >= 0 and start < x_shape[dim]
offset = [0] * len(x_shape)
offset[dim] = start
# import pdb; pdb.set_trace()
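
generate_not_num evaluates a symbolic shape expression by tokenizing it, ranking operators by a nesting-adjusted priority, and folding the highest-priority operator into an Ascend op proxy (Add/Sub/Mul/Div) while a disjoint set tracks which operand positions have already been merged. The sketch below mirrors that strategy on plain integers instead of op proxies so it can be run standalone; it illustrates the approach and is not the PR's code:

# standalone sketch of the fold-by-priority / disjoint-set evaluation used by
# generate_not_num, operating on plain integers instead of Ascend op proxies

OPS = ('+', '-', '*', '//')

def find_root(set_num, i):
    # follow parent links until a self-parented root is reached
    while set_num[i] != i:
        i = set_num[i]
    return i

def eval_expr(tokens):
    # tokens: operands and operators already split, e.g.
    # ['(', '32', '//', '8', ')', '-', '(', '32', '//', '16', ')']
    set_num, priority, nest = [], [], 0
    for idx, t in enumerate(tokens):
        if t in ('+', '-'):
            priority.append(nest * 3 + 0)
        elif t in ('*', '//'):
            priority.append(nest * 3 + 1)
        else:
            if t == '(':
                nest += 1
            elif t == ')':
                nest -= 1
            priority.append(-1)
        set_num.append(idx if t not in OPS + ('(', ')') else -1)

    values = list(tokens)
    while len(set(set_num)) > 2:          # more than one operand group left
        m_idx = max(range(len(priority)), key=lambda i: priority[i])
        left_idx, right_idx = m_idx - 1, m_idx + 1
        while left_idx > 0 and values[left_idx] in ('(', ')'):
            left_idx -= 1
        while right_idx < len(values) - 1 and values[right_idx] in ('(', ')'):
            right_idx += 1
        left_root = find_root(set_num, set_num[left_idx])
        right_root = find_root(set_num, set_num[right_idx])
        a, b = int(values[left_root]), int(values[right_root])
        op = tokens[m_idx]
        values[left_root] = {'+': a + b, '-': a - b, '*': a * b, '//': a // b}[op]
        # fold the right group into the left group's root
        set_num = [left_root if s != -1 and find_root(set_num, s) == right_root else s
                   for s in set_num]
        priority[m_idx] = -1

    root = 0
    while values[root] in ('(', ')'):
        root += 1
    return int(values[root])

print(eval_expr(['(', '32', '//', '8', ')', '-', '(', '32', '//', '16', ')']))  # 2
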