How to write a scikit-learn compatible estimator/transformer
Formal Metadata
Title: How to write a scikit-learn compatible estimator/transformer
Series: FOSDEM 2020, talk 418 of 490
License: CC Attribution 2.0 Belgium: You are free to use, adapt and copy, distribute and transmit the work or content in adapted or unchanged form for any legal purpose as long as the work is attributed to the author in the manner specified by the author or licensor.
Identifiers: 10.5446/47287 (DOI)
Transcript: English (auto-generated)
00:08
Okay, thank you everyone. For the next talk, Adrin is going to tell us about custom scikit-learn transformers.
00:20
Okay, thank you. Hello everybody, I'm Adrin, I work at Anaconda and I'm one of the scikit-learn maintainers. Today I'm going to talk about how to write your own estimator.
00:40
Before I start, I need to know: how many people have used scikit-learn? Okay, cool, so I don't need to focus too much on the background, that's good. For those of us who are not familiar with it, it's a statistical machine learning library. That means it covers all the old-school stuff: support vector
01:03
machines, random forests, k-means and whatnot. It does not include the deep learning ones, and it doesn't cover GPU acceleration; that's just not in its scope. That said, when we look at the library, what are some of the main components of
01:21
the library? Before we start writing our own estimator, we need to understand that. We have estimators: estimators are either transformers, in which case they take some data, transform it, and spit it out, or they are predictors, that is, classifiers or regressors. Then we have scorers: we need to know how these models perform, so we have different
01:44
scorers to measure their performance in different ways. Then we have meta-estimators. Meta-estimators take an estimator and do something with it. Two of the important and relevant ones for this talk are Pipeline, which allows you
02:01
to have a set of transformers and then, if you will, a predictor at the end: you have your classifier at the end and your transformers before that. And then it lets you treat that whole pipeline as one single estimator. And then you have grid search, which is easier to explain with a little example.
02:22
In the usual pipeline, we have our data, and we need to preprocess and prepare the data to give to our classifier, in the case of classification. So in this case I have two steps to prepare the data, and then I feed that to an SGD classifier. But each of these steps usually has some hyperparameters that you can tune.
02:44
If it's a transformer doing principal component analysis, how many components do you want to return? That number. Are you doing k-means? That k. If you're regularizing, what is the regularization parameter?
03:01
How do you do that? That's your parameter set, and that set defines your space. And now you want to search in that space and find the best point for your data. Grid search does that for you. You pass it your estimator and your parameter space and it does the search.
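A sketch of this in code, with a two-step preprocessing pipeline feeding an SGD classifier as described above (the step names and the parameter grid are illustrative, not from the talk):

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_iris(return_X_y=True)

# Two preprocessing steps followed by a classifier, treated as one estimator.
pipe = Pipeline([
    ("scale", StandardScaler()),
    ("pca", PCA()),
    ("clf", SGDClassifier(random_state=0)),
])

# The hyperparameters of every step define the search space,
# addressed as "<step name>__<parameter name>".
param_grid = {
    "pca__n_components": [2, 3],
    "clf__alpha": [1e-4, 1e-2],
}

search = GridSearchCV(pipe, param_grid, cv=3)
search.fit(X, y)
print(search.best_params_)
```

Passing a `scoring` argument to `GridSearchCV` is how you would swap in a different scorer than the default one.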
03:23
If you want to use a different score than the default one, you can also pass that. So with all that flexibility, why would we want to write our own estimator? There are a couple of cases. One is that scikit-learn doesn't have all the algorithms out there. It does have the classical ones, but it's not really possible for us to include everything.
03:44
So if you fancy a new algorithm, it's probably not there. Or if you are a researcher who would like to implement their own and work on their own method, you probably want to write your own and then see how it works in combination with the other methods and transformers out there.
04:00
Or my favorite example: if you're doing ethics. Ethics, bias mitigation and bias detection are not in the scope of scikit-learn. So if you want to work on that, you would have to write your own scikit-learn compatible estimators and mitigate your bias. We also don't include things that are extremely specific to certain use cases.
04:20
If you need to do something that applies only to your data, you probably need to write that and it's not going to be included in the library. Another use case is writing meta-estimators. If you want to do something before or after every time you call an estimator, you can easily write a meta-estimator, wrap around your estimators, and then do logging or auditing or whatnot.
04:42
So what's the basic API? What does it look like? Estimators expose fit to train on the data; predict if they're doing classification or regression; transform if they're a transformer; and score if they're a predictor, since you need to know how they perform. When I look at people's code trying to write their own estimators, it looks as if
05:06
they watched this talk, or the equivalent of this talk, and then stopped here. So please don't. The estimator I'm going to write is not really doing anything fancy, it's just
05:21
to show how you could write it. What are the components that you need? Before that, it is a very opinionated API and it has its own design. I know probably half of us in this room may not agree with that design, but that's what it is.
05:41
That's not the discussion; we can talk about it later. We do composition. That means that if you're writing an estimator, you have to have BaseEstimator. If you're writing a classifier, it's ClassifierMixin, and then depending on what you do, you would need different mixins: RegressorMixin, MetaEstimatorMixin, and a bunch of other ones.
06:02
We have a bunch of really nice methods to do input validation. You really don't need to write your own input validation. You don't need to check if the input is a NumPy array or not. All of that is there. Then my classifier is going to wrap around an SVC, in a very poor way.
06:22
Things to note here. I have my init, and in the init I accept my hyperparameters, and the only thing I do is that I store them. I store them in public attributes and I do no validation. That is important. All the validation goes into fit. In fit, I do input validation, and if needed, I do validation on my parameters.
06:47
If two parameters are not compatible and I need to check only one of them is set, this is where I do that. And then here I'm just storing my trained SVC in an estimator with a trailing underscore,
07:02
and that's, again, important. The convention is that attributes are public. If there's a trailing underscore, it is set in fit. If it's a leading underscore, it's private, and backward compatibility is not guaranteed. Then I have predict.
07:20
I check if I'm fitted. If yes, then I check my input, and then I delegate to my estimator's predict. So what did I use there? One of them was check_is_fitted. It checks if there is any attribute with a trailing underscore.
07:44
You can tune the behavior. check_array is a really long and important one. It returns a NumPy array, unless you say you explicitly want to support sparse arrays, in which case it doesn't convert a sparse array to a dense one. And if it's a Pandas data frame, it converts that to NumPy.
08:04
If you want to, for example, get your feature names from your Pandas data frame, you do that before passing it to check_array. And then check_X_y does the same thing, plus some extra validation on y.
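Putting those pieces together, a minimal sketch of such an SVC wrapper could look like this (the class and attribute names are illustrative, not the speaker's exact code):

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.svm import SVC
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class MyClassifier(ClassifierMixin, BaseEstimator):
    def __init__(self, C=1.0):
        # __init__ only stores hyperparameters in public attributes; no validation here.
        self.C = C

    def fit(self, X, y):
        # All input (and, if needed, parameter) validation happens in fit.
        X, y = check_X_y(X, y)
        self.classes_ = np.unique(y)  # classifiers must expose classes_
        # Trailing underscore: this attribute is set in fit.
        self.estimator_ = SVC(C=self.C).fit(X, y)
        return self

    def predict(self, X):
        check_is_fitted(self)  # passes if any attribute with a trailing underscore exists
        X = check_array(X)
        return self.estimator_.predict(X)


X = np.array([[0.0], [0.2], [2.8], [3.0]])
y = np.array([0, 0, 1, 1])
clf = MyClassifier(C=1.0).fit(X, y)
preds = clf.predict(np.array([[0.1], [2.9]]))
```

With just this in place, the estimator already works inside a pipeline or a grid search.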
08:22
Now that we have it, how can we be sure that it is compatible? Compatibility is usually checked through our common tests. We have check_estimator, which runs a whole bunch of tests, and we recently added the decorator parametrize_with_checks. You put that on top of your pytest test, and then it runs all the checks individually,
08:45
and you can easily check and debug what went wrong. For example, when I was writing this one, I forgot to set the classes_ attribute, which is needed if you're a classifier, and then it complained, and I went back and set it.
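Run against a built-in estimator, the common tests look roughly like this (a sketch; you would pass an instance of your own class instead of LogisticRegression):

```python
from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import check_estimator

# check_estimator runs the whole suite of common compatibility tests in one go
# and raises as soon as a convention is violated (for example, a classifier
# missing the classes_ attribute). With pytest installed, the
# parametrize_with_checks decorator instead turns each check into its own
# test case, which makes failures much easier to pinpoint and debug.
check_estimator(LogisticRegression())
print("all common checks passed")
```

The pytest form is `@parametrize_with_checks([MyClassifier()])` on a test function taking `(estimator, check)` and calling `check(estimator)`.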
09:00
Now that we have it, it's easy. You can use it the way you would use any other estimator. I have a bunch of data; I can fit on my data; I can get my score. I can put it in a pipeline: here I have a SelectKBest, and then my classifier. And then I can even pass that to a grid search.
09:22
I fit my grid search, and then if I check my best estimator, I see my classifier with the hyperparameter selected by the grid search. So what are some of the conventions? I pretty much mentioned all of them, except for the parameters passed to fit.
09:43
The one that you usually see in the existing scikit-learn API is sample_weight. But you could pass other stuff. You could pass groups. In the context of bias detection and fairness, we usually have
10:00
our protected attributes that are not part of the data, like gender, zip code, race, all of that. All of that you can pass to fit as a fit parameter. The convention is that everything you pass as a fit parameter should be sample-aligned. If you have feature attributes, probably don't pass them there. If you have something that you could pass as an init parameter, do that there.
10:21
And it's important because if you do pass things that are sample-aligned to grid search, when it does the folding and the cross-validation, it slices these extra parameters for you, and then passes them with your data to the fit function.
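A sketch of what that looks like with sample_weight (assuming a recent scikit-learn, where grid search slices sample-aligned fit parameters per cross-validation fold):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True)
rng = np.random.RandomState(0)
sample_weight = rng.uniform(0.5, 1.5, size=len(y))  # one weight per sample

search = GridSearchCV(LogisticRegression(max_iter=1000), {"C": [0.1, 1.0]}, cv=3)
# Because sample_weight is sample-aligned (length n_samples), grid search
# slices it along with X and y for every fold before calling the inner fit.
search.fit(X, y, sample_weight=sample_weight)
print(search.best_params_)
```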
10:45
These are the usual ones, but not all estimators follow all of that, and either other meta-estimators or some of the tests need to know that. So that's why we recently introduced estimator tags. They're still experimental, as in we may change them without prior notice, like they
11:04
don't go through the usual deprecation cycles, but they're pretty useful. You can tell the other meta-estimators or the tests what kinds of input you allow. Do you support multi-output? Do you accept NaNs? And then if you want to change any other defaults, you can do that by having a _more_tags
11:24
method. So what are we doing now? This is how it works now, but we are adding a bunch of stuff to the API, and they're useful,
11:46
but that means that you would also need to add to or change your API a little bit. The first one that is coming in, which hopefully will be there in the next release, is n_features_in_ and n_features_out_. We want to be able to inspect the models and know how many features went in, and
12:03
for a transformer, how many features are coming out. That's the first step, and it helps a lot for us also to clean up the code, but it also helps to understand what's going on in a pipeline. The step after that is that we want to have feature names.
12:22
Usually, if I have data which is not just a numerical block, say a Pandas data frame with a bunch of feature names, and I have a pipeline, I would like to follow how my features flow through the pipeline. If I have a bunch of transformers, and at the end a classifier, I want to know
12:41
what went into my classifier. If I have a linear model and I want to inspect its coefficients, I want to know which feature it is that now has a high coefficient there. Right now the API allows you to have get_feature_names, which returns the feature names, but it becomes ambiguous.
13:00
Are those the input feature names or the output feature names? Sometimes it's not clear. So we're deprecating that, and we're going to have feature_names_in_ and feature_names_out. Pretty clear. And that means that if you pass a Pandas data frame, it would extract the feature names for you, and then at the end you will have all of that propagated through the pipeline.
13:28
The next one that I'm really excited about is data properties: sample props, feature props and data props. Sample weight is an example.
13:40
Gender that I mentioned is another example. The issue there is that right now in the pipeline, if I have a pipeline and I want to pass that to a fit, I have to say, OK, pass this one to the fit of that step of the pipeline. And then if I want to pass the same sample weights to another fit, I need to copy that and say, well, also pass that to this one. And if I have a meta estimator, I don't know if the meta estimator should pass
14:04
that through or not. Maybe I have to duplicate that and pass the one that is used by the meta estimator and pass another one that is used by the one handled by the meta estimator. So it's really not clean. The idea here is that I could have a really nice routing and every step, the pipeline,
14:24
every meta estimator would know what needs to be passed to which step and not just fit, also score and predict. If you need to pass other features, other properties to them, then you should be able to pass them. That requires changing the API a little bit, and then there are some prototypes, and
14:41
hopefully they will go forward and we'll have them soon. But that's not all of it. I only showed you a really, really simple one. And for example, here, if I really wanted to write a clean estimator, if I am
15:01
wrapping around an estimator, I may be a classifier, but I'm also a meta-estimator, which means that I shouldn't have to care about which hyperparameters are there. The user should be able to pass an estimator, and I should be able to know what the parameters are. You don't have to do that yourself. You can use the MetaEstimatorMixin.
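A sketch of such a wrapping meta-estimator, here one that logs every call before delegating (an illustrative example, not code from the talk):

```python
import numpy as np
from sklearn.base import (BaseEstimator, ClassifierMixin,
                          MetaEstimatorMixin, clone)
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import check_array, check_is_fitted, check_X_y


class LoggingClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator):
    def __init__(self, estimator):
        # The wrapped estimator is itself a hyperparameter: store it untouched,
        # so its own hyperparameters stay reachable via get_params / set_params.
        self.estimator = estimator

    def fit(self, X, y):
        X, y = check_X_y(X, y)
        print(f"fit called with {X.shape[0]} samples")  # the "do something" part
        self.classes_ = np.unique(y)
        # clone() gives a fresh, unfitted copy with the same hyperparameters.
        self.estimator_ = clone(self.estimator).fit(X, y)
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = check_array(X)
        print(f"predict called with {X.shape[0]} samples")
        return self.estimator_.predict(X)


X = np.array([[0.0], [0.2], [2.8], [3.0]])
y = np.array([0, 0, 1, 1])
wrapped = LoggingClassifier(LogisticRegression()).fit(X, y)
preds = wrapped.predict(np.array([[0.1], [2.9]]))
```

The wrapper behaves like the classifier it holds, but every fit and predict call goes through the logging code first.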
15:21
And all of that you can see in the pointers I give here. This documentation covers most of the stuff I talked about. The file base.py has everything except the meta-estimator parts. There are other mixins that you could probably use.
15:42
And then there's the meta-estimator one, and validation.py has a lot more utility functions that you could use. Thanks. I'll take questions now.
16:07
We have time for lots of questions. So, can this estimator work with TensorFlow data structures, or ones from Keras, or
16:25
only with basic data structures? If I have data in TensorFlow in such a structure, can I just insert it into this estimator from scikit-learn, or will it require some kind of other conversions
16:43
along the way. So if I understand the question, it's whether you could put, for example, a PyTorch model as an estimator here, or something like that. The idea is that the default API of any of those libraries doesn't
17:03
follow this API, but usually what happens is that they have an sklearn wrapper. So you could, I don't remember which one has it where, but for example, if I see pytorch.sklearn, then I know that that's where I can find my scikit-learn compatible estimators. And then those estimators, they wrap around their own estimators, but they expose an
17:24
API, which is compatible here. Therefore you could take that and plug it into a pipeline here. Okay. So in principle it can work. People do that. Okay. Thank you.
17:47
Don't be shy, raise your hands. Nothing? Okay. Thank you.