《机器学习实战》CART回归树源码问题:TypeError: list indices must be integers or slices, not tuple
作者:互联网
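报错原因:书中代码假定 dataSet 是 NumPy 矩阵,使用 `dataSet[:, feature]` 这样的元组下标;如果传入的是 Python 列表,列表的下标不能是元组,于是抛出上述 TypeError。下面把相关函数改写为逐行遍历列表的版本(另一种做法是先用 `mat(dataSet)` 把数据转换成矩阵再调用书中代码)。
书中代码1: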
def binSplitDataSet(dataSet, feature, value):
    """Split the rows of ``dataSet`` into two matrices on one feature.

    Args:
        dataSet: 2-D data; a NumPy matrix or anything ``asmatrix`` accepts
            (e.g. a list of row lists).
        feature: index of the column to compare.
        value: threshold; rows with ``row[feature] > value`` go to ``mat0``.

    Returns:
        ``(mat0, mat1)``: matrices holding every row whose feature is
        ``> value`` and every row whose feature is ``<= value``.
    """
    from numpy import asmatrix, nonzero

    # Converting up front lets plain Python lists use the 2-D ``[:, j]``
    # indexing below instead of raising "list indices must be integers or
    # slices, not tuple". A matrix input is returned unchanged.
    dataSet = asmatrix(dataSet)
    # The book's trailing ``[0]`` kept only the FIRST matching row; all
    # matching rows are returned here.
    mat0 = dataSet[nonzero(dataSet[:, feature] > value)[0], :]
    mat1 = dataSet[nonzero(dataSet[:, feature] <= value)[0], :]
    return mat0, mat1
改成:
def binSplitDataSet(dataSet, feature, value):
    """Split a list of rows on one feature without NumPy 2-D indexing.

    Args:
        dataSet: iterable of rows (e.g. a list of lists); each row must
            support ``row[feature]``.
        feature: index of the column to compare.
        value: threshold; rows with ``row[feature] > value`` go to ``mat0``.

    Returns:
        ``(mat0, mat1)``: two lists holding the original row objects whose
        feature is ``> value`` and ``<= value`` respectively, in input order.
    """
    mat0 = []
    mat1 = []
    # Route each row directly. The earlier version looked rows up again with
    # ``featList.index(feat)``, which always returns the FIRST row having that
    # feature value, so rows sharing a value were duplicated while others
    # were dropped (and the repeated lookup made the split O(n^2)).
    for row in dataSet:
        if row[feature] > value:
            mat0.append(row)
        else:
            mat1.append(row)
    return mat0, mat1
书中代码2:
def regLeaf(dataSet):
    """Return the leaf value for the regression tree: the mean of the target.

    The target is the last column of ``dataSet``, which may be a NumPy
    matrix/array or a list of row lists.
    """
    from numpy import asarray, mean

    # ``asarray`` makes the 2-D ``[:, -1]`` index valid for plain lists too;
    # for a matrix the resulting mean is the same scalar as before.
    return mean(asarray(dataSet)[:, -1])
改成:
def regLeaf(dataSet):
    """Leaf value for the list-based tree: mean of each row's final element."""
    return mean([row[-1] for row in dataSet])
书中代码3:
def regErr(dataSet):
    """Return the total squared error of the target (last) column.

    Computed as the population variance of the last column times the number
    of rows, i.e. the sum of squared deviations from the column mean.
    ``dataSet`` may be a NumPy matrix/array or a list of row lists.
    """
    from numpy import asarray, var

    # ``asarray`` makes ``[:, -1]`` work for lists; the row count equals the
    # length of the extracted column.
    targets = asarray(dataSet)[:, -1]
    return var(targets) * len(targets)
改成:
def regErr(dataSet):
    """Return the total squared error of the targets in a list data set.

    The target is the last element of each row. The result is the sum of
    squared deviations from the target mean (variance times row count).

    Raises:
        ZeroDivisionError: if ``dataSet`` is empty.
    """
    targets = [row[-1] for row in dataSet]
    # Local names avoid shadowing NumPy's ``mean``/``var`` used elsewhere.
    avg = sum(targets) / len(targets)
    # Summing squared deviations directly avoids the previous divide-by-n
    # then multiply-by-n round trip (extra float rounding) and the
    # ``shape`` call, which rejects rows of unequal length.
    return sum((t - avg) ** 2 for t in targets)
书中代码4:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Find the best binary split of ``dataSet`` (book, matrix-based version).

    ``dataSet`` is expected to be a NumPy matrix whose last column is the
    target; the ``.T.tolist()[0]`` calls below rely on matrix semantics.

    Args:
        dataSet: NumPy matrix of rows; last column is the target.
        leafType: callable turning a data set into a leaf value.
        errType: callable scoring a data set; lower is better.
        ops: ``(tolS, tolN)``: the minimum error reduction required to split,
            and the minimum number of rows allowed on each side of a split.

    Returns:
        ``(bestIndex, bestValue)`` for the chosen split, or
        ``(None, leafType(dataSet))`` when the node should become a leaf.
    """
    tolS, tolN = ops[0], ops[1]
    # Exit 1: every target value is identical, so there is nothing to split.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)
    _, n = shape(dataSet)
    S = errType(dataSet)
    bestS = inf
    bestIndex = 0
    bestValue = 0
    for featIndex in range(n - 1):
        # Collect plain Python values: iterating the column matrix itself
        # yields 1x1 matrices, which are unhashable, so the book's
        # ``set(dataSet[:, featIndex])`` raised TypeError.
        for splitVal in set(dataSet[:, featIndex].T.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    # Exit 2: the best split does not reduce the error by at least tolS.
    # If no split passed the tolN filter, bestS is inf and this also fires.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    # Exit 3: a side of the chosen split is smaller than tolN.
    # NOTE(review): any split that set bestS already passed this check, so
    # this looks unreachable; kept as a safeguard.
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):
        return None, leafType(dataSet)
    return bestIndex, bestValue
改成:
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Choose the (feature, threshold) pair that best splits a list data set.

    List-based counterpart of the book version: columns are read element by
    element from each row instead of with ``[:, j]`` matrix indexing.

    Args:
        dataSet: list of rows; the last element of each row is the target.
        leafType: callable turning a data set into a leaf value.
        errType: callable scoring a data set; lower is better.
        ops: ``(tolS, tolN)``: the required error reduction and the minimum
            number of rows on each side of a split.

    Returns:
        ``(bestIndex, bestValue)`` when a split is worthwhile, otherwise
        ``(None, leafType(dataSet))``.
    """
    tolS, tolN = ops[0], ops[1]

    # Stop when every target value is the same.
    if len(set(row[-1] for row in dataSet)) == 1:
        return None, leafType(dataSet)

    _, n = shape(dataSet)
    baseErr = errType(dataSet)
    lowestErr, chosenFeat, chosenVal = inf, 0, 0

    # The last column is the target, so only the first n - 1 are features.
    for col in range(n - 1):
        candidates = {row[col] for row in dataSet}
        for threshold in candidates:
            left, right = binSplitDataSet(dataSet, col, threshold)
            if shape(left)[0] < tolN or shape(right)[0] < tolN:
                continue
            splitErr = errType(left) + errType(right)
            if splitErr < lowestErr:
                lowestErr, chosenFeat, chosenVal = splitErr, col, threshold

    # Not enough improvement (or no admissible split at all): make a leaf.
    if baseErr - lowestErr < tolS:
        return None, leafType(dataSet)

    left, right = binSplitDataSet(dataSet, chosenFeat, chosenVal)
    if shape(left)[0] < tolN or shape(right)[0] < tolN:
        return None, leafType(dataSet)
    return chosenFeat, chosenVal
运行结果:
标签:integers,TypeError,return,mat0,mat1,dataSet,shape,源码,valueList 来源: https://blog.csdn.net/qq_38757545/article/details/122723358